def __init__(self, iterator_resource, name):
  # Serialize the iterator's current position into a variant tensor and
  # expose it to the saver as a single SaveSpec slice.
  serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
  specs = [
      BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
  ]
  super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)
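# Usage sketch (illustrative, not part of this module): wrapping a TF1
# tf.data iterator's state so a Saver checkpoints it alongside variables.
# `_iterator_resource` is a private attribute of the TF1 Iterator, and the
# collection-based registration is an assumption of this example.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   dataset = tf.data.Dataset.range(100)
#   iterator = tf.data.make_initializable_iterator(dataset)
#   saveable = _IteratorSaveable(iterator._iterator_resource, "input_iter")
#   tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
#   saver = tf.train.Saver()  # now saves/restores the iterator position too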
def __init__(self, var_list=None, **kwargs):
  # Force non-sequential restore and install this module's BaseSaverBuilder
  # before delegating to the standard Saver constructor.
  kwargs['restore_sequentially'] = False
  kwargs['builder'] = BaseSaverBuilder()
  super().__init__(var_list=var_list, **kwargs)
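# Usage sketch (illustrative): this Saver override pins
# restore_sequentially=False and swaps in the module's BaseSaverBuilder, so
# callers construct it exactly like tf.compat.v1.train.Saver. The class name
# `EVSaver` is hypothetical.
#
#   saver = EVSaver(max_to_keep=3)
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.global_variables_initializer())
#     ckpt = saver.save(sess, "/tmp/model.ckpt")
#     saver.restore(sess, ckpt)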
def _init_from_args(self,
                    embedding_dim,
                    initializer=None,
                    trainable=True,
                    collections=None,
                    caching_device=None,
                    name=None,
                    ktype=None,
                    vtype=None,
                    constraint=None,
                    synchronization=None,
                    aggregation=None,
                    distribute_strategy=None,
                    invalid_key=-1):
  """Creates a variable.

  Args:
    embedding_dim: An integer or `TensorShape` giving the shape of each
      embedding value; it is passed to `initializer` to build the initial
      value.
    initializer: A callable taking a `shape` keyword argument that returns
      the initial value for the EmbeddingVariable: a `Tensor`, a Python
      object convertible to a `Tensor`, or a callable with no argument that
      returns the initial value when called. (Note that initializer
      functions from init_ops.py must first be bound to a shape before
      being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used
      as the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collection keys. The new variable is added
      to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    caching_device: Forwarded to the `ResourceVariable` base initializer.
    name: Optional name for the variable. Defaults to `'EmbeddingVariable'`
      and gets uniquified automatically.
    ktype: If set, the EV's keys will be converted to the given type. If
      None, int32 will be used.
    vtype: If set, the initial value will be converted to the given type.
      If None, either the datatype will be kept (if the initial value is a
      Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
    constraint: An optional projection function to be applied to the
      variable after being updated by an `Optimizer` (e.g. used to
      implement norm constraints or value constraints for layer weights).
      The function must take as input the unprojected Tensor representing
      the value of the variable and return the Tensor for the projected
      value (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
    synchronization: Forwarded to the `ResourceVariable` base initializer.
    aggregation: Forwarded to the `ResourceVariable` base initializer.
    distribute_strategy: Forwarded to the `ResourceVariable` base
      initializer.
    invalid_key: Sentinel key value (default -1) passed to the EV
      initialization op.

  @compatibility(eager)
  When Eager Execution is enabled, variables are never added to
  collections. It is not implicitly added to the `GLOBAL_VARIABLES` or
  `TRAINABLE_VARIABLES` collections, and the `collections` argument is
  ignored.
  @end_compatibility
  """
  if isinstance(embedding_dim, tensor_shape.TensorShape):
    embedding_shape = embedding_dim
  elif isinstance(embedding_dim, six.integer_types):
    embedding_shape = [embedding_dim]
  else:
    raise ValueError(
        "embedding_dim argument to EmbeddingVariable constructor must be "
        "an integer or TensorShape. Got %s of type %s" %
        (embedding_dim, type(embedding_dim)))
  initial_value = initializer(shape=embedding_shape)
  init_from_fn = callable(initial_value)
  if ktype is None:
    ktype = dtypes.int32
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to EmbeddingVariable constructor must be a "
        "list, tuple, or set. Got %s of type %s" %
        (collections, type(collections)))
  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")
  self._initializer = initializer
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  with ops.init_scope():
    self._in_graph_mode = not context.executing_eagerly()
    with ops.name_scope(name, "EmbeddingVariable",
                        [] if init_from_fn else [initial_value],
                        skip_on_eager=False) as name:
      # pylint: disable=protected-access
      self._invalid_key = invalid_key
      self._ktype = ktype
      handle_name = ops.name_from_scope_name(name)
      if self._in_graph_mode:
        shared_name = handle_name
        unique_id = shared_name
      else:
        # When in eager mode use a uid for the shared_name, to prevent
        # accidental sharing.
        unique_id = "%s_%d" % (handle_name, ops.uid())
        shared_name = unique_id
      # Use attr_scope and device(None) to simulate the behavior of
      # colocate_with when the variable we want to colocate with doesn't
      # yet exist.
      device_context_manager = (ops.device if self._in_graph_mode
                                else ops.NullContextmanager)
      attr = attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(
              s=[compat.as_bytes("loc:@%s" % handle_name)]))
      with ops.get_default_graph()._attr_scope({"_class": attr}):
        with ops.name_scope("Initializer"), device_context_manager(None):
          if init_from_fn:
            initial_value = initial_value()
          if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value
          initial_value = ops.convert_to_tensor(initial_value,
                                                name="initial_value",
                                                dtype=vtype)
        shape = initial_value.shape
        handle = self._embedding_variable_handle(
            shape=initial_value.get_shape(),
            dtype=initial_value.dtype.base_dtype,
            shared_name=shared_name,
            name=name,
            graph_mode=self._in_graph_mode)
      # pylint: disable=protected-access
      if (self._in_graph_mode and initial_value is not None and
          initial_value.op._get_control_flow_context() is not None):
        raise ValueError(
            "Initializer for variable %s is from inside a control-flow "
            "construct, such as a loop or conditional. When creating a "
            "variable inside a loop or conditional, use a lambda as the "
            "initializer." % name)
      # pylint: enable=protected-access
      vtype = initial_value.dtype.base_dtype
      if self._in_graph_mode:
        with ops.name_scope("IsInitialized"):
          self._ev_is_initialized_op = gen_ev_ops.ev_is_initialized_op(
              handle, Tkey=self._ktype, Tvalue=vtype)
        if initial_value is not None:
          # pylint: disable=g-backslash-continuation
          with ops.name_scope("Initialize") as n, \
               ops.colocate_with(None, ignore_existing=True), \
               ops.device(handle.device):
            # pylint: disable=protected-access
            initializer_op = gen_ev_ops.initialize_ev_op(
                handle,
                variables._try_guard_against_uninitialized_dependencies(
                    name, initial_value),
                ops.convert_to_tensor(invalid_key, dtype=self._ktype),
                shape=initial_value.get_shape(),
                name=n)
        cached_value = None
        graph_element = None
      else:
        gen_ev_ops.initialize_ev_op(
            handle, initial_value,
            ops.convert_to_tensor(invalid_key, dtype=self._ktype),
            shape=initial_value.get_shape())
        self._ev_is_initialized_op = None
        initializer_op = None
        graph_element = None
        cached_value = None
      if not context.executing_eagerly():
        # Eager variables are only added to collections if they are part of
        # an eager variable store (otherwise in an interactive session they
        # would hog memory and cause OOM). This is done in
        # ops/variable_scope.py.
        ops.add_to_collections(collections, self)
      elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
  initial_value = initial_value if self._in_graph_mode else None
  # Prepend a size-0 leading dimension: the number of keys held by the EV
  # is dynamic, so only the per-key value shape is static.
  new_dim = shape.as_list()
  new_dim.insert(0, 0)
  new_shape = tensor_shape.TensorShape(new_dim)
  super(resource_variable_ops.ResourceVariable, self).__init__(
      trainable=trainable,
      shape=new_shape,
      dtype=vtype,
      handle=handle,
      synchronization=synchronization,
      constraint=constraint,
      aggregation=aggregation,
      distribute_strategy=distribute_strategy,
      name=name,
      unique_id=unique_id,
      handle_name=handle_name,
      graph_element=graph_element,
      initial_value=initial_value,
      initializer_op=initializer_op,
      is_initialized_op=self._ev_is_initialized_op,
      cached_value=cached_value,
      caching_device=caching_device)
  # Export the EV's live keys and values so checkpoints save them as two
  # separate slices.
  tensors = gen_ev_ops.ev_export(self.handle, Tkey=self._ktype, Tvalue=vtype)
  self.specs = [
      BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
      BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values"),
  ]
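# Usage sketch (illustrative): constructing an EmbeddingVariable with int64
# keys and 8-dim float32 values. The `EmbeddingVariable` constructor name
# and its forwarding to _init_from_args are assumptions based on this
# method's signature; only embedding_dim/initializer/ktype/invalid_key are
# taken from the code above.
#
#   ev = EmbeddingVariable(
#       embedding_dim=8,
#       initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.1),
#       ktype=tf.dtypes.int64,
#       invalid_key=-1,
#       name="user_embedding")
#   # Checkpoints then store the live key/value pairs via the two SaveSpecs
#   # ("-keys" and "-values") produced by gen_ev_ops.ev_export, rather than
#   # a dense tensor.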