def __init__(self, params, tree_num, training, tree_config='', tree_stat=''):
    """Create the variables backing a single tree.

    Ensures `params` carries a `TensorForestParams` proto and its serialized
    form, then builds the stats variable (only when training) and the tree
    variable.

    Args:
      params: Parameter object. Mutated in place: `params_proto` is (re)built
        via `build_params_proto` when it is missing or is not a
        `TensorForestParams`, and `serialized_params_proto` is always
        refreshed from it.
      tree_num: Identifier forwarded to `self.get_tree_name` to derive the
        names of the variables.
      training: When truthy, a fertile-stats variable is created and stored
        on `self.stats`; otherwise `self.stats` stays `None`.
      tree_config: Forwarded as the second argument of
        `model_ops.tree_variable` (presumably a serialized initial tree;
        confirm against that op).
      tree_stat: Forwarded as the second argument of
        `stats_ops.fertile_stats_variable` (presumably serialized initial
        stats; confirm against that op).
    """
    proto_is_usable = (
        hasattr(params, 'params_proto') and
        isinstance(params.params_proto, _params_proto.TensorForestParams))
    if not proto_is_usable:
        params.params_proto = build_params_proto(params)

    # Always re-serialize so the string form matches the current proto.
    serialized = params.params_proto.SerializeToString()
    params.serialized_params_proto = serialized

    # Stats are only needed while training; inference keeps this as None.
    self.stats = None
    if training:
        # TODO(gilberth): Manually shard this to be able to fit it on
        # multiple machines.
        stats_name = self.get_tree_name('stats', tree_num)
        self.stats = stats_ops.fertile_stats_variable(
            params, tree_stat, stats_name)

    tree_name = self.get_tree_name('tree', tree_num)
    self.tree = model_ops.tree_variable(
        params, tree_config, self.stats, tree_name)
def __init__(self, params, tree_num, training, tree_config='', tree_stat=''):
    """Create the stats and tree variables for one tree.

    `tree_config` and `tree_stat` default to the empty strings that were
    previously hard-coded, so existing three-argument callers behave exactly
    as before while new callers can supply initial values.

    Args:
      params: Parameter object. Mutated in place: `params_proto` is built
        with `build_params_proto` if it is absent or is not a
        `TensorForestParams`; `serialized_params_proto` is then set from it.
      tree_num: Identifier passed to `self.get_tree_name` when naming the
        variables.
      training: When truthy, a fertile-stats variable is created on
        `self.stats`; otherwise `self.stats` is `None`.
      tree_config: Value handed to `model_ops.tree_variable` as its second
        argument. Defaults to `''`.
      tree_stat: Value handed to `stats_ops.fertile_stats_variable` as its
        second argument. Defaults to `''`.
    """
    if (not hasattr(params, 'params_proto') or
            not isinstance(params.params_proto,
                           _params_proto.TensorForestParams)):
        params.params_proto = build_params_proto(params)

    # Keep the serialized form in sync with whichever proto is now attached.
    params.serialized_params_proto = params.params_proto.SerializeToString()

    # No stats variable is created outside of training.
    self.stats = None
    if training:
        # TODO(gilberth): Manually shard this to be able to fit it on
        # multiple machines.
        self.stats = stats_ops.fertile_stats_variable(
            params, tree_stat, self.get_tree_name('stats', tree_num))

    self.tree = model_ops.tree_variable(
        params, tree_config, self.stats,
        self.get_tree_name('tree', tree_num))