def __init__(self,
             stamp_token,
             gradient_shape,
             hessian_shape,
             name=None,
             container=None):
  """Creates a stats accumulator resource and registers it for checkpointing.

  Args:
    stamp_token: An int64, initial value to use for the stamp token.
    gradient_shape: A TensorShape, containing shape of gradients.
    hessian_shape: A TensorShape, containing shape of hessians.
    name: A name for the stats accumulator variable. Characters matched by
      `_PATTERN` are removed before it is used as a name scope.
    container: An optional `string` resource container (default `None`).
  """
  # Store the construction arguments before building the resource.
  # `_is_scalar` is read below when the saveable is created. The other
  # attributes are presumably consumed by `_create_resource`/`_initialize`
  # (defined elsewhere on the class -- TODO confirm).
  self._stamp_token = stamp_token
  self._gradient_shape = gradient_shape
  self._hessian_shape = hessian_shape
  self._container = container
  # Scalar mode only when both gradients and hessians are scalars.
  if (gradient_shape == tensor_shape.scalar() and
      hessian_shape == tensor_shape.scalar()):
    self._is_scalar = True
  else:
    self._is_scalar = False

  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "StatsAccumulator") as name:
    self._name = name
    self._resource_handle = self._create_resource()
    self._init_op = self._initialize()
    is_initialized_op = self.is_initialized()
  resources.register_resource(self.resource_handle, self.initializer,
                              is_initialized_op)
  self._saveable = StatsAccumulatorSaveable(
      self.resource_handle, self.initializer, self._is_scalar, name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
def tree_variable(params, tree_config, stats_handle, name, container=None):
  r"""Builds a decision-tree resource and returns its handle.

  Inside a name scope derived from `name`, this builds the tree resource
  handle, its create op and its is-initialized op. It then adds a
  `TreeVariableSavable` (named after the handle) to the `SAVEABLE_OBJECTS`
  collection and registers the resource with `resources`.

  Args:
    params: A TensorForestParams object; `params.serialized_params_proto` is
      passed to the create op.
    tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
    stats_handle: Resource handle to the stats object, forwarded to the
      saveable.
    name: A name for the variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The handle produced by `decision_tree_resource_handle_op`.
  """
  with ops.name_scope(name, "TreeVariable") as scope:
    handle = gen_model_ops.decision_tree_resource_handle_op(
        container, shared_name=scope, name=scope)
    init_op = gen_model_ops.create_tree_variable(
        handle, tree_config, params=params.serialized_params_proto)
    ready_op = gen_model_ops.tree_is_initialized_op(handle)

    # Make the tree checkpointable under its handle's graph name.
    tree_saveable = TreeVariableSavable(
        params, handle, stats_handle, init_op, handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, tree_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def tree_ensemble_variable(stamp_token,
                           tree_ensemble_config,
                           name,
                           container=None):
  r"""Builds a tree-ensemble resource and returns its handle.

  The resource is created inside a "TreeEnsembleVariable"-defaulted name
  scope. A `TreeEnsembleVariableSavable` named after the handle is added to
  `SAVEABLE_OBJECTS`, and the resource is registered with `resources`.

  Args:
    stamp_token: The initial stamp token value for the ensemble resource.
    tree_ensemble_config: A `Tensor` of type `string`. Serialized proto of
      the tree ensemble.
    name: A name for the ensemble variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The handle produced by `decision_tree_ensemble_resource_handle_op`.
  """
  with ops.name_scope(name, "TreeEnsembleVariable") as scope:
    handle = gen_model_ops.decision_tree_ensemble_resource_handle_op(
        container, shared_name=scope, name=scope)
    init_op = gen_model_ops.create_tree_ensemble_variable(
        handle, stamp_token, tree_ensemble_config)
    ready_op = gen_model_ops.tree_ensemble_is_initialized_op(handle)

    # Checkpoint the ensemble under its handle's graph name.
    ensemble_saveable = TreeEnsembleVariableSavable(handle, init_op,
                                                    handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, ensemble_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def fertile_stats_variable(params, stats_config, name, container=None):
  r"""Builds a fertile-stats resource and returns its handle.

  Args:
    params: A TensorForestParams object; `params.serialized_params_proto` is
      passed to the create op and `params` to the saveable.
    stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
    name: A name for the variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The handle produced by `fertile_stats_resource_handle_op`.
  """
  with ops.name_scope(name, "FertileStatsVariable") as scope:
    handle = gen_stats_ops.fertile_stats_resource_handle_op(
        container, shared_name=scope, name=scope)
    init_op = gen_stats_ops.create_fertile_stats_variable(
        handle, stats_config, params=params.serialized_params_proto)
    ready_op = gen_stats_ops.fertile_stats_is_initialized_op(handle)

    # Checkpoint the stats under the handle's graph name.
    stats_saveable = FertileStatsVariableSavable(params, handle, init_op,
                                                 handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, stats_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def tree_variable(params, tree_config, stats_handle, name, container=None):
  r"""Builds a decision-tree resource and returns its handle.

  The resolved scope name is passed positionally as the handle op's second
  argument. The saveable is registered as "tree_checkpoint_<scope>".

  Args:
    params: A TensorForestParams object; `params.serialized_params_proto` is
      passed to the create op.
    tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
    stats_handle: Resource handle to the stats object, forwarded to the
      saveable.
    name: A name for the variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The handle produced by `decision_tree_resource_handle_op`.
  """
  with ops.name_scope(name, "TreeVariable") as scope:
    handle = gen_model_ops.decision_tree_resource_handle_op(
        container, scope, name=scope)
    init_op = gen_model_ops.create_tree_variable(
        handle, tree_config, params=params.serialized_params_proto)
    ready_op = gen_model_ops.tree_is_initialized_op(handle)

    checkpoint_name = "tree_checkpoint_{0}".format(scope)
    tree_saveable = TreeVariableSavable(params, handle, stats_handle, init_op,
                                        checkpoint_name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, tree_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def __init__(self,
             init_stamp_token,
             epsilon,
             num_quantiles,
             max_elements=None,
             name=None,
             container=None):
  """Creates a QuantileAccumulator object.

  Args:
    init_stamp_token: The initial value for the stamp token.
    epsilon: Error bound on the quantile computation.
    num_quantiles: Number of quantiles to produce from the final summary.
    max_elements: Maximum number of elements added to the accumulator.
    name: the name to save the accumulator under. May be `None`, in which
      case `name_scope` falls back to "QuantileAccumulator".
    container: An optional `string` resource container (default `None`).
  """
  self._epsilon = epsilon
  # `name` defaults to None, but `_PATTERN.sub` (presumably a compiled regex)
  # raises TypeError when given None. Only sanitize explicit names.
  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "QuantileAccumulator") as name:
    self._quantile_accumulator_handle = (
        gen_quantile_ops.quantile_stream_resource_handle_op(
            container=container, shared_name=name, name=name))
    self._create_op = gen_quantile_ops.create_quantile_accumulator(
        self._quantile_accumulator_handle,
        init_stamp_token,
        epsilon=epsilon,
        max_elements=max_elements,
        num_quantiles=num_quantiles)
    is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized(
        self._quantile_accumulator_handle)
  resources.register_resource(self._quantile_accumulator_handle,
                              self._create_op, is_initialized_op)
  self._make_savable(name)
def __init__(self,
             initial_value=None,
             name=None,
             trainable=True,
             collections=None,
             dtype=None,
             shape=None):
  """Creates a resource-backed variable.

  Args:
    initial_value: A `Tensor` or Python object convertible to a `Tensor`
      holding the variable's initial value. When omitted, no create op is
      built and the resource is not registered.
    name: The name of this variable. Automatically uniquified.
    trainable: Forwarded to `_register_dense_variable_read`.
    collections: Forwarded to `_register_dense_variable_read`.
    dtype: The variable's type. If omitted, it is taken from
      `initial_value`. If given, `initial_value` is cast to it.
    shape: The variable's shape. When omitted, it comes from `initial_value`
      or is unknown.
  """
  if initial_value is not None:
    initial_value = ops.convert_to_tensor(initial_value)

  if dtype is not None:
    if initial_value is not None:
      initial_value = math_ops.cast(initial_value, dtype)
  else:
    assert initial_value is not None, ("Trying to create a resource variable "
                                       "with no dtype or initial value. At"
                                       " least one of these must be set.")
    dtype = initial_value.dtype

  if shape is not None:
    shape = tensor_shape.as_shape(shape)
  elif initial_value is not None:
    shape = initial_value.get_shape().as_proto()
  else:
    shape = tensor_shape.unknown_shape()

  self._dtype = dtype
  with ops.name_scope(name, "Variable", [initial_value]) as scope:
    self._handle = var_handle_op(shared_name=scope, name=scope, dtype=dtype,
                                 shape=shape)
    with ops.name_scope("IsInitialized"):
      self._is_initialized_op = var_is_initialized_op(self._handle)
    if initial_value is not None:
      with ops.name_scope("Create"):
        self._initialize_op = create_variable_op(self._handle, initial_value)
      resources.register_resource(self._handle, self._initialize_op,
                                  self._is_initialized_op)
    with ops.name_scope("Read"):
      self._value = read_variable_op(self._handle, dtype=self._dtype)
    _register_dense_variable_read(self._value, trainable=trainable,
                                  collections=collections)
def __init__(self, type_name, name, container, config, resource_handle_func,
             create_op_func, is_initialized_op_func, serialize_op_func,
             deserialize_op_func):
  """Builds a resource through the supplied op factories and makes it saveable.

  Args:
    type_name: Default name-scope name used when `name` is None.
    name: Requested name. The resolved scope name becomes the resource's
      shared name and the name of the single save spec.
    container: Container passed to `resource_handle_func`.
    config: Initial configuration passed to `create_op_func`.
    resource_handle_func: Called as `(container, shared_name=..., name=...)`.
    create_op_func: Called as `(handle, config)` to build the create op.
    is_initialized_op_func: Called with the handle.
    serialize_op_func: Called with the handle; its result is what is saved.
    deserialize_op_func: Stored as `self._deserialize_op_func`.
  """
  with ops.name_scope(name, type_name) as scope:
    handle = resource_handle_func(container, shared_name=scope, name=scope)
    self._resource_handle = handle
    self._is_initialized_op = is_initialized_op_func(handle)
    serialized = serialize_op_func(handle)
    self._create_op = create_op_func(handle, config)

    # The resource is saved whole, so the slice spec is always empty.
    save_spec = saver.BaseSaverBuilder.SaveSpec(serialized, '', scope)
    super(TreeVariableSaveable, self).__init__(handle, [save_spec], scope)

    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
    resources.register_resource(handle, self._create_op,
                                self._is_initialized_op)
    self._deserialize_op_func = deserialize_op_func
def __init__(self,
             init_stamp_token,
             epsilon,
             num_quantiles,
             name=None,
             container=None):
  """Creates a QuantileAccumulator object.

  Args:
    init_stamp_token: The initial value for the stamp token.
    epsilon: Error bound on the quantile computation.
    num_quantiles: Number of quantiles to produce from the final summary.
    name: the name to save the accumulator under. May be `None`, in which
      case `name_scope` falls back to "QuantileAccumulator".
    container: An optional `string` resource container (default `None`).
  """
  self._epsilon = epsilon
  # `name` defaults to None, but `_PATTERN.sub` (presumably a compiled regex)
  # raises TypeError when given None. Only sanitize explicit names.
  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "QuantileAccumulator") as name:
    self._quantile_accumulator_handle = (
        gen_quantile_ops.quantile_stream_resource_handle_op(
            container=container, shared_name=name, name=name))
    self._create_op = gen_quantile_ops.create_quantile_accumulator(
        self._quantile_accumulator_handle,
        init_stamp_token,
        epsilon=epsilon,
        num_quantiles=num_quantiles)
    is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized(
        self._quantile_accumulator_handle)
  resources.register_resource(self._quantile_accumulator_handle,
                              self._create_op, is_initialized_op)
  self._make_savable(name)
def fertile_stats_variable(params, stats_config, name, container=None):
  r"""Creates a `FertileStatsVariable` and returns its resource handle.

  The variable's own checkpoint factory (key "fertile_stats_variable") builds
  the saveable, named after the handle. The saveable is added to
  `SAVEABLE_OBJECTS` and the resource is registered with `resources`.

  Args:
    params: A TensorForestParams object.
    stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
    name: A name for the variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The `resource_handle` of the created `FertileStatsVariable`.
  """
  with ops.name_scope(name, "FertileStatsVariable") as scope:
    stats_var = FertileStatsVariable(params, stats_config, scope, container)
    handle = stats_var.resource_handle
    init_op = stats_var.initializer
    ready_op = stats_var.is_initialized()

    factories = stats_var._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    stats_saveable = factories["fertile_stats_variable"](name=handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, stats_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
  """Creates a boosted-trees ensemble resource.

  Args:
    name: Requested name. The resolved 'TreeEnsemble' scope name is the
      resource's shared name.
    stamp_token: Initial stamp token passed to the create op.
    is_local: When true, no saveable is added and the resource is registered
      with `is_shared=False`.
    serialized_proto: Passed to the create op as `tree_ensemble_serialized`.
  """
  with ops.name_scope(name, 'TreeEnsemble') as scope:
    self._resource_handle = (
        gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
            container='', shared_name=scope, name=scope))
    init_op = gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        stamp_token,
        tree_ensemble_serialized=serialized_proto)
    ready_op = gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self._resource_handle)

    if not is_local:
      # Only non-local ensembles are checkpointed.
      ensemble_saveable = _TreeEnsembleSavable(
          self.resource_handle, init_op, self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, ensemble_saveable)

    resources.register_resource(
        self.resource_handle, init_op, ready_op, is_shared=not is_local)
def fertile_stats_variable(params, stats_config, name, container=None):
  r"""Builds a fertile-stats resource and returns its handle.

  The resolved scope name is passed positionally as the handle op's second
  argument. The saveable is registered as "stats_checkpoint_<scope>".

  Args:
    params: A TensorForestParams object; `params.serialized_params_proto` is
      passed to the create op.
    stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
    name: A name for the variable.
    container: An optional `string` resource container (default `None`).

  Returns:
    The handle produced by `fertile_stats_resource_handle_op`.
  """
  with ops.name_scope(name, "FertileStatsVariable") as scope:
    handle = gen_stats_ops.fertile_stats_resource_handle_op(
        container, scope, name=scope)
    init_op = gen_stats_ops.create_fertile_stats_variable(
        handle, stats_config, params=params.serialized_params_proto)
    ready_op = gen_stats_ops.fertile_stats_is_initialized_op(handle)

    checkpoint_name = "stats_checkpoint_{0}".format(scope)
    stats_saveable = FertileStatsVariableSavable(params, handle, init_op,
                                                 checkpoint_name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, stats_saveable)

    resources.register_resource(handle, init_op, ready_op)
    return handle
def saveable(self):
  """Returns the saveable object stored on this instance as `_saveable`."""
  # The statements that followed this `return` in the original were
  # unreachable and referenced names not defined in this scope (`specs`,
  # `name`, `create_op`, `is_initialized_op`), so they have been removed.
  return self._saveable
def create_resource(self, name, eps, max_elements, num_streams=1):
  """Creates and registers a quantile stream resource, returning its handle.

  Args:
    name: Used as both the shared name and the op name of the handle.
    eps: Passed to the create op as `epsilon`.
    max_elements: Passed to the create op.
    num_streams: Passed to the create op; defaults to 1.

  Returns:
    The handle produced by `resource_handle_op`.
  """
  handle = resource_handle_op(container="", shared_name=name, name=name)
  init_op = boosted_trees_ops.create_quantile_stream_resource(
      handle, epsilon=eps, max_elements=max_elements, num_streams=num_streams)
  resources.register_resource(handle, init_op, resource_initialized(handle))
  return handle
def _make_summary_writer(name, factory, **kwargs):
  """Creates a summary-writer resource and wraps it in a `SummaryWriter`.

  `factory(resource, **kwargs)` builds the init op. It is called once here,
  and the same callable is handed to `SummaryWriter`. Outside eager mode the
  init op is also added to `_SUMMARY_WRITER_INIT_COLLECTION_NAME`.

  Args:
    name: Shared name of the summary-writer resource.
    factory: Callable `(resource, **kwargs)` returning the init op.
    **kwargs: Extra keyword arguments forwarded to `factory`.

  Returns:
    A `SummaryWriter` built from the resource and the init-op callable.
  """
  resource = gen_summary_ops.summary_writer(shared_name=name)

  def build_init_op():
    return factory(resource, **kwargs)

  init_op = build_init_op()
  if not context.executing_eagerly():
    # TODO(apassos): Consider running init_op in the default session instead.
    ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
  # TODO(nickfelt): expose an actual op for this
  always_initialized = constant_op.constant(True)
  resources.register_resource(resource, init_op, always_initialized)
  return SummaryWriter(resource, build_init_op)
def test_run_feeds_iter_calls_resources_init(self):
  """A registered resource reports initialized inside `run_feeds_iter`."""
  with tf.Graph().as_default():
    in0, _, _ = self._build_inference_graph()
    stub_handle = test_ops.stub_resource_handle_op(container='a',
                                                   shared_name='b')
    resources.register_resource(
        handle=stub_handle,
        create_op=test_ops.resource_create_op(stub_handle),
        is_initialized_op=test_ops.resource_initialized_op(stub_handle))

    feeds = {'in0': in0}
    for _ in learn.graph_actions.run_feeds_iter(feeds, feed_dicts=[{}]):
      initialized = test_ops.resource_initialized_op(stub_handle).eval()
      self.assertTrue(initialized)
def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
  """Creates a boosted-trees ensemble resource.

  Args:
    name: Requested name. The resolved 'TreeEnsemble' scope name is stored
      as `self._name`.
    stamp_token: Initial stamp token, stored as `self._stamp_token`.
    is_local: When true, no saveable is added and the resource is registered
      with `is_shared=False`.
    serialized_proto: Serialized ensemble, stored as
      `self._serialized_proto`.
  """
  # Store construction arguments before building the resource. They are
  # presumably read by `_create_resource`/`_initialize` (defined elsewhere on
  # the class -- TODO confirm); previously they were never stored.
  self._stamp_token = stamp_token
  self._serialized_proto = serialized_proto
  self._is_local = is_local
  with ops.name_scope(name, 'TreeEnsemble') as name:
    self._name = name
    self._resource_handle = self._create_resource()
    self._init_op = self._initialize()
    is_initialized_op = self.is_initialized()
    # Adds the variable to the savable list. The original referenced an
    # undefined `create_op` here (NameError); the init op built above is the
    # resource's create op.
    if not is_local:
      self._saveable = _TreeEnsembleSavable(
          self.resource_handle, self._init_op, self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
    resources.register_resource(
        self.resource_handle,
        self._init_op,
        is_initialized_op,
        is_shared=not is_local)
def __init__(self,
             epsilon,
             num_streams,
             num_quantiles,
             name=None,
             max_elements=None):
  """Creates and registers a quantile stream resource.

  Args:
    epsilon: Stored as `_eps` and passed to the create op.
    num_streams: Stored as `_num_streams` and passed to the create op.
    num_quantiles: Stored as `_num_quantiles`.
    name: Optional name. The resolved 'QuantileAccumulator' scope name is the
      resource's shared name and is passed to `_make_saveable`.
    max_elements: Accepted but not used by this constructor.
  """
  with ops.name_scope(name, 'QuantileAccumulator') as scope:
    self._eps, self._num_streams, self._num_quantiles = (
        epsilon, num_streams, num_quantiles)
    handle = quantile_resource_handle_op(
        container='', shared_name=scope, name=scope)
    self._resource_handle = handle
    self._create_op = create_quantile_stream_resource(handle, epsilon,
                                                      num_streams)
    ready_op = is_quantile_resource_initialized(handle)
  resources.register_resource(handle, self._create_op, ready_op)
  self._make_saveable(scope)
def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
  """Creates a boosted-trees ensemble resource via the class's helpers.

  Args:
    name: Requested name. The resolved 'TreeEnsemble' scope name is stored
      as `self._name`.
    stamp_token: Stored as `self._stamp_token`.
    is_local: Stored as `self._is_local`. When true, no saveable is created
      and the resource is registered with `is_shared=False`.
    serialized_proto: Stored as `self._serialized_proto`.
  """
  self._stamp_token, self._serialized_proto, self._is_local = (
      stamp_token, serialized_proto, is_local)
  with ops.name_scope(name, 'TreeEnsemble') as scope:
    self._name = scope
    self._resource_handle = self.create_resource()
    self._init_op = self.initialize()
    ready_op = self.is_initialized()

    if not is_local:
      # Only non-local ensembles are checkpointed.
      self._saveable = _TreeEnsembleSavable(
          self.resource_handle, self.initializer, self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

    resources.register_resource(
        self.resource_handle, self.initializer, ready_op,
        is_shared=not is_local)
def __init__(self,
             stamp_token,
             gradient_shape,
             hessian_shape,
             name=None,
             container=None):
  """Creates a stats accumulator resource and a checkpoint saveable for it.

  Args:
    stamp_token: An int64, initial value to use for the stamp token.
    gradient_shape: A TensorShape, containing shape of gradients.
    hessian_shape: A TensorShape, containing shape of hessians.
    name: Optional name. Characters matched by `_PATTERN` are removed before
      it is used as a name scope.
    container: Optional resource container (default `None`).
  """
  self._stamp_token = stamp_token
  self._gradient_shape = gradient_shape
  self._hessian_shape = hessian_shape
  self._container = container
  # Scalar mode only when both gradients and hessians are scalars.
  self._is_scalar = bool(gradient_shape == tensor_shape.scalar() and
                         hessian_shape == tensor_shape.scalar())

  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "StatsAccumulator") as scope:
    self._name = scope
    self._resource_handle = self._create_resource()
    self._init_op = self._initialize()
    ready_op = self.is_initialized()

  resources.register_resource(self.resource_handle, self.initializer,
                              ready_op)
  self._saveable = StatsAccumulatorSaveable(
      self.resource_handle, self.initializer, self._is_scalar, scope)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
def __init__(self,
             init_stamp_token,
             epsilon,
             num_quantiles,
             max_elements=None,
             name=None,
             container=None,
             generate_quantiles=False):
  """Creates a QuantileAccumulator object.

  Args:
    init_stamp_token: The initial value for the stamp token.
    epsilon: Error bound on the quantile computation.
    num_quantiles: Number of quantiles to produce from the final summary.
    max_elements: Maximum number of elements added to the accumulator.
    name: the name to save the accumulator under. May be `None`, in which
      case `name_scope` falls back to "QuantileAccumulator".
    container: An optional `string` resource container (default `None`).
    generate_quantiles: Generate quantiles instead of approximate boundaries.
      If true, exactly `num_quantiles` will be produced in the final summary.
  """
  self._epsilon = epsilon
  self._generate_quantiles = generate_quantiles
  # `name` defaults to None, but `_PATTERN.sub` (presumably a compiled regex)
  # raises TypeError when given None. Only sanitize explicit names.
  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "QuantileAccumulator") as name:
    self._quantile_accumulator_handle = (
        gen_quantile_ops.quantile_stream_resource_handle_op(
            container=container, shared_name=name, name=name))
    self._create_op = gen_quantile_ops.create_quantile_accumulator(
        self._quantile_accumulator_handle,
        init_stamp_token,
        epsilon=epsilon,
        max_elements=max_elements,
        num_quantiles=num_quantiles,
        generate_quantiles=generate_quantiles)
    is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized(
        self._quantile_accumulator_handle)
  resources.register_resource(self._quantile_accumulator_handle,
                              self._create_op, is_initialized_op)
  self._make_savable(name)
def __init__(self,
             init_stamp_token,
             epsilon,
             num_quantiles,
             max_elements=None,
             name=None,
             container=None,
             generate_quantiles=False):
  """Creates a QuantileAccumulator object.

  Args:
    init_stamp_token: The initial value for the stamp token.
    epsilon: Error bound on the quantile computation.
    num_quantiles: Number of quantiles to produce from the final summary.
    max_elements: Maximum number of elements added to the accumulator.
    name: the name to save the accumulator under. May be `None`, in which
      case `name_scope` falls back to "QuantileAccumulator".
    container: An optional `string` resource container (default `None`).
    generate_quantiles: Generate quantiles instead of approximate boundaries.
      If true, exactly `num_quantiles` will be produced in the final summary.
  """
  self._init_stamp_token = init_stamp_token
  self._epsilon = epsilon
  self._num_quantiles = num_quantiles
  self._max_elements = max_elements
  self._container = container
  self._generate_quantiles = generate_quantiles
  super(QuantileAccumulator, self).__init__()

  # `name` defaults to None, but `_PATTERN.sub` (presumably a compiled regex)
  # raises TypeError when given None. Only sanitize explicit names.
  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "QuantileAccumulator") as name:
    self._name = name
    self._resource_handle = self._create_resource()
    self._init_op = self._initialize()
    is_initialized_op = self.is_initialized()
  resources.register_resource(self.resource_handle, self._init_op,
                              is_initialized_op)
  self._saveable = QuantileAccumulatorSaveable(self.resource_handle,
                                               self._init_op, name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
def __init__(self,
             epsilon,
             num_streams,
             num_quantiles,
             name=None,
             max_elements=None):
  """Creates a quantile stream resource through the class's public helpers.

  Args:
    epsilon: Stored as `_eps`.
    num_streams: Stored as `_num_streams` and passed to the saveable.
    num_quantiles: Stored as `_num_quantiles`.
    name: Optional name. The resolved 'QuantileAccumulator' scope name is
      stored as `_name`.
    max_elements: Accepted but not used by this constructor.
  """
  self._eps, self._num_streams, self._num_quantiles = (
      epsilon, num_streams, num_quantiles)
  super(QuantileAccumulator, self).__init__()
  with ops.name_scope(name, 'QuantileAccumulator') as scope:
    self._name = scope
    self._resource_handle = self.create_resource()
    self._init_op = self.initialize()
    ready_op = self.is_initialized()

  resources.register_resource(self.resource_handle, self._init_op, ready_op)
  self._saveable = QuantileAccumulatorSaveable(
      self.resource_handle, self._init_op, self._num_streams,
      self.resource_handle.name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
  """Creates a boosted-trees ensemble resource.

  Args:
    name: Requested name. The resolved 'TreeEnsemble' scope name is the
      resource's shared name.
    stamp_token: Initial stamp token passed to the create op.
    is_local: When true, the ensemble is not added to `SAVEABLE_OBJECTS` and
      is registered with `is_shared=False`.
    serialized_proto: Passed to the create op as `tree_ensemble_serialized`.
  """
  with ops.name_scope(name, 'TreeEnsemble') as scope:
    handle_op = gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op
    self._resource_handle = handle_op(
        container='', shared_name=scope, name=scope)
    init_op = gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle, stamp_token,
        tree_ensemble_serialized=serialized_proto)
    ready_op = gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self._resource_handle)

    # Local ensembles are not checkpointed.
    if not is_local:
      ensemble_saveable = _TreeEnsembleSavable(
          self.resource_handle, init_op, self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, ensemble_saveable)

    resources.register_resource(
        self.resource_handle, init_op, ready_op, is_shared=not is_local)
def __init__(self,
             epsilon,
             num_streams,
             num_quantiles,
             name=None,
             max_elements=None):
  """Creates a quantile stream resource through the class's private helpers.

  Args:
    epsilon: Stored as `_eps`.
    num_streams: Stored as `_num_streams` and passed to the saveable.
    num_quantiles: Stored as `_num_quantiles`.
    name: Optional name. The resolved 'QuantileAccumulator' scope name is
      stored as `_name`.
    max_elements: Accepted but not used by this constructor.
  """
  self._eps, self._num_streams, self._num_quantiles = (
      epsilon, num_streams, num_quantiles)
  super(QuantileAccumulator, self).__init__()
  with ops.name_scope(name, 'QuantileAccumulator') as scope:
    self._name = scope
    self._resource_handle = self._create_resource()
    self._init_op = self._initialize()
    ready_op = self.is_initialized()

  resources.register_resource(self.resource_handle, self._init_op, ready_op)
  self._saveable = QuantileAccumulatorSaveable(
      self.resource_handle, self._init_op, self._num_streams,
      self.resource_handle.name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
def __init__(self, type_name, name, container, config, resource_handle_func,
             create_op_func, is_initialized_op_func, serialize_op_func,
             deserialize_op_func):
  """Builds a resource via the given op factories and registers it as saveable.

  Args:
    type_name: Default name-scope name used when `name` is None.
    name: Requested name. The resolved scope name is the resource's shared
      name and the name of the single save spec.
    container: Container passed to `resource_handle_func`.
    config: Initial configuration passed to `create_op_func`.
    resource_handle_func: Called as `(container, shared_name=..., name=...)`.
    create_op_func: Called as `(handle, config)`.
    is_initialized_op_func: Called with the handle.
    serialize_op_func: Called with the handle; produces the saved tensor.
    deserialize_op_func: Stored as `self._deserialize_op_func`.
  """
  with ops.name_scope(name, type_name) as scope:
    handle = resource_handle_func(container, shared_name=scope, name=scope)
    self._resource_handle = handle
    self._is_initialized_op = is_initialized_op_func(handle)
    serialized = serialize_op_func(handle)
    self._create_op = create_op_func(handle, config)

    # Slicing is meaningless for this resource, so the slice spec is empty.
    specs = [saver.BaseSaverBuilder.SaveSpec(serialized, '', scope)]
    super(TreeVariableSaveable, self).__init__(handle, specs, scope)

    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
    resources.register_resource(handle, self._create_op,
                                self._is_initialized_op)
    self._deserialize_op_func = deserialize_op_func
def __init__(self,
             initial_value=None,
             name=None,
             trainable=True,
             collections=None,
             dtype=None,
             shape=None):
  """Creates a resource-backed variable.

  Args:
    initial_value: An `Output` or Python object convertible to an `Output`
      holding the initial value. When omitted, no create op is built and the
      resource is not registered.
    name: The name of this variable. Automatically uniquified.
    trainable: Forwarded to `_register_variable_read`.
    collections: Forwarded to `_register_variable_read`.
    dtype: The variable's type. If omitted, it is taken from
      `initial_value`. If given, `initial_value` is cast to it.
    shape: The variable's shape. When omitted, it comes from `initial_value`
      or is unknown.
  """
  if initial_value is not None:
    initial_value = ops.convert_to_tensor(initial_value)

  if dtype is not None:
    if initial_value is not None:
      initial_value = math_ops.cast(initial_value, dtype)
  else:
    assert initial_value is not None, (
        "Trying to create a resource variable "
        "with no dtype or initial value. At"
        " least one of these must be set.")
    dtype = initial_value.dtype

  if shape is not None:
    shape = tensor_shape.as_shape(shape)
  elif initial_value is not None:
    shape = initial_value.get_shape().as_proto()
  else:
    shape = tensor_shape.unknown_shape()

  self._dtype = dtype
  rv_ops = gen_resource_variable_ops
  with ops.name_scope(name, "Variable", [initial_value]) as scope:
    self._handle = rv_ops.var_handle_op(
        shared_name=scope, name=scope, dtype=dtype, shape=shape)
    with ops.name_scope("IsInitialized"):
      self._is_initialized_op = rv_ops.var_is_initialized_op(self._handle)
    if initial_value is not None:
      with ops.name_scope("Create"):
        self._initialize_op = rv_ops.create_variable_op(self._handle,
                                                        initial_value)
      resources.register_resource(self._handle, self._initialize_op,
                                  self._is_initialized_op)
    with ops.name_scope("Read"):
      self._value = rv_ops.read_variable_op(self._handle, dtype=self._dtype)
    _register_variable_read(self._value, trainable=trainable,
                            collections=collections)
def __init__(self,
             stamp_token,
             gradient_shape,
             hessian_shape,
             name=None,
             container=None):
  """Creates a stats accumulator resource that is also its own saveable.

  Scalar resource ops are used when both shapes equal
  `tensor_shape.scalar()`. Otherwise the tensor ops are used, and the create
  op receives both shapes as lists.

  Args:
    stamp_token: An int64, initial value to use for the stamp token.
    gradient_shape: A TensorShape, containing shape of gradients.
    hessian_shape: A TensorShape, containing shape of hessians.
    name: Optional name. Characters matched by `_PATTERN` are removed before
      it is used as a name scope.
    container: Optional resource container (default `None`).
  """
  acc_ops = gen_stats_accumulator_ops
  if name is not None:
    name = _PATTERN.sub("", name)
  with ops.name_scope(name, "StatsAccumulator") as scope:
    self._is_scalar = bool(gradient_shape == tensor_shape.scalar() and
                           hessian_shape == tensor_shape.scalar())
    if self._is_scalar:
      self._resource_handle = (
          acc_ops.stats_accumulator_scalar_resource_handle_op(
              container, scope, name=scope))
      create_op = acc_ops.create_stats_accumulator_scalar(
          self._resource_handle, stamp_token)
      is_initialized_op = acc_ops.stats_accumulator_scalar_is_initialized(
          self._resource_handle)
    else:
      self._resource_handle = (
          acc_ops.stats_accumulator_tensor_resource_handle_op(
              container, scope, name=scope))
      create_op = acc_ops.create_stats_accumulator_tensor(
          self._resource_handle, stamp_token, gradient_shape.as_list(),
          hessian_shape.as_list())
      is_initialized_op = acc_ops.stats_accumulator_tensor_is_initialized(
          self._resource_handle)
    self._create_op = create_op

    # Save-spec names are the handle's graph name plus a per-field suffix,
    # paired positionally with the tuple returned by serialize().
    # NOTE(review): "hessians" lacks the leading "_" the other suffixes have.
    # It is kept as-is because changing it would change checkpoint keys.
    saver_name = self._resource_handle.name
    (stamp, num_updates, partition_ids, feature_ids, gradients,
     hessians) = self.serialize()
    fields = (
        (stamp, "_stamp"),
        (num_updates, "_num_updates"),
        (partition_ids, "_partition_ids"),
        (feature_ids, "_feature_ids"),
        (gradients, "_gradients"),
        (hessians, "hessians"),
    )
    specs = [
        saver.BaseSaverBuilder.SaveSpec(tensor, "", saver_name + suffix)
        for tensor, suffix in fields
    ]
    super(StatsAccumulator, self).__init__(self._resource_handle, specs, scope)
    resources.register_resource(self._resource_handle, create_op,
                                is_initialized_op)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
def __init__(self,
             stamp_token,
             gradient_shape,
             hessian_shape,
             name=None,
             container=None):
  """Creates a stats accumulator resource that is also its own saveable.

  Scalar resource ops are used when both shapes equal
  `tensor_shape.scalar()`. Otherwise the tensor ops are used, and the create
  op receives both shapes as lists.

  Args:
    stamp_token: An int64, initial value to use for the stamp token.
    gradient_shape: A TensorShape, containing shape of gradients.
    hessian_shape: A TensorShape, containing shape of hessians.
    name: A name for the stats accumulator variable.
    container: An optional `string` resource container (default `None`).
  """
  with ops.name_scope(name, "StatsAccumulator") as name:
    # Both values are scalars.
    if (gradient_shape == tensor_shape.scalar() and
        hessian_shape == tensor_shape.scalar()):
      self._is_scalar = True
      self._resource_handle = (
          gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op(
              container, name, name=name))
      create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar(
          self._resource_handle, stamp_token)
      is_initialized_op = (
          gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized(
              self._resource_handle))
    else:
      self._is_scalar = False
      self._resource_handle = (
          gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op(
              container, name, name=name))
      create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor(
          self._resource_handle, stamp_token, gradient_shape.as_list(),
          hessian_shape.as_list())
      # Bug fix: this branch builds a *tensor* accumulator, so it must use
      # the tensor is-initialized op. The original used the scalar one.
      is_initialized_op = (
          gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized(
              self._resource_handle))
    self._create_op = create_op

    slice_spec = ""
    saver_name = self._resource_handle.name
    # Use distinct local names so the `stamp_token` argument is not shadowed.
    (saved_stamp, num_updates, partition_ids, feature_ids, gradients,
     hessians) = self.serialize()
    specs = [
        saver.BaseSaverBuilder.SaveSpec(saved_stamp, slice_spec,
                                        saver_name + "_stamp"),
        saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec,
                                        saver_name + "_num_updates"),
        saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec,
                                        saver_name + "_partition_ids"),
        saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec,
                                        saver_name + "_feature_ids"),
        saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec,
                                        saver_name + "_gradients"),
        # NOTE(review): "hessians" lacks the leading "_" the other keys have.
        # It is kept because renaming it would change checkpoint keys.
        saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec,
                                        saver_name + "hessians"),
    ]

    super(StatsAccumulator, self).__init__(self._resource_handle, specs, name)
    resources.register_resource(self._resource_handle, create_op,
                                is_initialized_op)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)