Beispiel #1
0
def build_params_proto(params):
    """Translate a V4ForestHParams object into a TensorForestParams proto.

    Scalar hyperparameters are copied onto the proto field by field, the
    four sample/split-count fields are filled in through
    ``parse_number_or_string_to_proto``, and, when ``params.param_file`` is
    non-empty, the text-format proto stored in that file is merged into the
    result last (so it can override values set from ``params``).

    Args:
      params: The hyperparameter object to read from.

    Returns:
      The populated TensorForestParams message.
    """
    proto = _params_proto.TensorForestParams()

    # (proto field, hparams attribute) pairs copied verbatim, in order.
    leading_copies = (
        ('num_trees', 'num_trees'),
        ('max_nodes', 'max_nodes'),
        ('is_regression', 'regression'),
        ('num_outputs', 'num_classes'),
        ('num_features', 'num_features'),
        ('leaf_type', 'leaf_model_type'),
        ('stats_type', 'stats_model_type'),
    )
    for field_name, attr_name in leading_copies:
        setattr(proto, field_name, getattr(params, attr_name))

    # This variant always uses the basic split-collection type.
    proto.collection_type = _params_proto.COLLECTION_BASIC
    proto.pruning_type.type = params.pruning_type
    proto.finish_type.type = params.finish_type
    proto.inequality_test_type = params.split_type
    proto.drop_final_class = False

    # Flags whose proto field and hparams attribute share the same name.
    same_named_flags = (
        'collate_examples',
        'checkpoint_stats',
        'use_running_stats_method',
        'initialize_average_splits',
        'inference_tree_paths',
    )
    for flag in same_named_flags:
        setattr(proto, flag, getattr(params, flag))

    # Sub-messages populated from a number-or-string hyperparameter.
    number_or_string_fields = (
        (proto.pruning_type.prune_every_samples, params.prune_every_samples),
        (proto.finish_type.check_every_steps,
         params.early_finish_check_every_samples),
        (proto.split_after_samples, params.split_after_samples),
        (proto.num_splits_to_consider, params.num_splits_to_consider),
    )
    for target, raw_value in number_or_string_fields:
        parse_number_or_string_to_proto(target, raw_value)

    proto.dominate_fraction.constant_value = params.dominate_fraction

    # An optional text-format file is merged in last.
    param_file = params.param_file
    if param_file:
        with open(param_file) as handle:
            text_format.Merge(handle.read(), proto)

    return proto
Beispiel #2
0
def build_params_proto(params):
  """Convert a V4ForestHParams object into a TensorForestParams proto.

  Scalar hyperparameters are copied directly. Each ``v4_*`` string
  hyperparameter for a count-like field is, when non-empty, parsed as a
  text-format proto and merged into the matching sub-message. When it is
  empty, that sub-message's ``constant_value`` is set to a default derived
  from the plain (non-v4) hyperparameters. A non-empty
  ``params.v4_param_file`` is read and merged into the proto last.

  Args:
    params: The hyperparameter object to read from.

  Returns:
    The populated TensorForestParams message.
  """

  def merge_or_constant(text, target, compute_default):
    """Merge `text` into `target`, or set target.constant_value to a default.

    `compute_default` is a zero-argument callable that is only invoked when
    `text` is empty, so defaults depending on other hyperparameters (and any
    logging they do) are evaluated lazily.
    """
    if not text:
      target.constant_value = compute_default()
      return
    text_format.Merge(text, target)

  def default_prune_every_samples():
    # Pruning half-way through split_after_samples is a reasonable default:
    # the number pruned is still chosen via v4_pruning_type, and pruning does
    # not run too often. The heuristic only holds when split_after_samples is
    # not depth-dependent, hence the error below.
    if params.v4_split_after_samples:
      logging.error(
          'If using depth-dependent split_after_samples and also pruning, '
          'need to set v4_prune_every_samples')
    return params.split_after_samples / 2

  proto = _params_proto.TensorForestParams()

  # (proto field, hparams attribute) pairs copied verbatim, in order.
  direct_copies = (
      ('num_trees', 'num_trees'),
      ('max_nodes', 'max_nodes'),
      ('is_regression', 'regression'),
      ('num_outputs', 'num_classes'),
      ('num_features', 'num_features'),
      ('leaf_type', 'v4_leaf_model_type'),
      ('stats_type', 'v4_stats_model_type'),
      ('collection_type', 'v4_split_collection_type'),
  )
  for field_name, attr_name in direct_copies:
    setattr(proto, field_name, getattr(params, attr_name))

  proto.pruning_type.type = params.v4_pruning_type
  proto.finish_type.type = params.v4_finish_type
  proto.inequality_test_type = params.v4_split_type
  proto.drop_final_class = False

  flag_copies = (
      ('collate_examples', 'v4_collate_examples'),
      ('checkpoint_stats', 'v4_checkpoint_stats'),
      ('use_running_stats_method', 'v4_use_running_stats_method'),
      ('initialize_average_splits', 'v4_initialize_average_splits'),
  )
  for field_name, attr_name in flag_copies:
    setattr(proto, field_name, getattr(params, attr_name))

  merge_or_constant(params.v4_prune_every_samples,
                    proto.pruning_type.prune_every_samples,
                    default_prune_every_samples)
  # Default: check for finish every quarter of split_after_samples. This is
  # more often than pruning because the check is cheaper (at least for
  # hoeffding), but not so often that its cost adds up.
  merge_or_constant(params.v4_finish_check_every_samples,
                    proto.finish_type.check_every_steps,
                    lambda: int(params.split_after_samples / 4))
  merge_or_constant(params.v4_split_after_samples,
                    proto.split_after_samples,
                    lambda: params.split_after_samples)
  merge_or_constant(params.v4_num_splits_to_consider,
                    proto.num_splits_to_consider,
                    lambda: params.num_splits_to_consider)

  proto.dominate_fraction.constant_value = params.dominate_fraction
  # NOTE(review): this uses the plain split_after_samples even when a
  # v4_split_after_samples string was supplied above -- confirm intended.
  proto.min_split_samples.constant_value = params.split_after_samples

  # An optional text-format file is merged in last.
  param_file = params.v4_param_file
  if param_file:
    with open(param_file) as handle:
      text_format.Merge(handle.read(), proto)

  return proto