def _read_variable_op(self, do_prefetch=True):
    """Return this variable's (transformed) value.

    In "train" mode the value is read from the resource handle, optionally
    after prefetching fresh values into it; otherwise the prefetched values
    are returned directly.
    """
    resource_variable_ops.variable_accessed(self)
    if self.model_mode == "train":
        if do_prefetch:
            # Serialize the assign before the read so the read observes the
            # freshly prefetched values.
            assign_op = gen_resource_variable_ops.assign_variable_op(
                self._handle,
                self.prefetch_values(),
                name="AssignBeforeReadVariable")
            with ops.control_dependencies([assign_op]):
                raw_value = gen_resource_variable_ops.read_variable_op(
                    self._handle, self._dtype)
        else:
            raw_value = gen_resource_variable_ops.read_variable_op(
                self._handle, self._dtype)
    else:
        raw_value = self.prefetch_values()
    if not context.executing_eagerly():
        # Note that if a control flow context is active the input of the read
        # op might not actually be the handle. This line bypasses it.
        tape.record_operation("ReadVariableOp", [raw_value], [self._handle],
                              lambda x: [x])
    return self.transform(raw_value)
def _skip_single_var(self, var, delta):
    """Advance `var`'s RNG state by `delta` and return the pre-skip state."""
    resource_variable_ops.variable_accessed(var)
    # TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
    algorithm = math_ops.cast(self.algorithm, dtypes.int32)
    skip_count = math_ops.cast(delta, dtypes.uint64)
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle, alg=algorithm, delta=skip_count)
def embedding_lookup_sparse(embedding_variable, sp_ids, slot_num, training=True):
    """This function is a wrapper of SOK's sparse forward propagation."""
    if not isinstance(sp_ids, sparse_tensor.SparseTensor):
        raise TypeError("sp_ids must be SparseTensor")
    keys = sp_ids.values
    # SparseTensor indices are rank-2 (nnz, 2); column 0 is the sample row.
    checked_indices = check_ops.ensure_shape(sp_ids.indices, shape=(None, 2))
    row_ids = array_ops.transpose(checked_indices, perm=[1, 0])[0]
    layer = embedding_variable.embedding_layer
    resource_variable_ops.variable_accessed(embedding_variable)
    replica_id = get_global_replica_id(_get_comm_tool())
    return kit_lib.plugin_sparse_fprop(
        embedding_variable._handle,
        layer.handle,
        keys,
        row_ids,
        replica_id,
        slot_num=slot_num,
        training=training,
        unique_op_name=embedding_variable.name,
        dtype=layer.compute_dtype)
def _read_variable_op(self):
    """Build a read of this embedding variable and record it on the tape."""
    variable_accessed(self)
    value = kit_lib.read_embedding_variable(
        self._handle, self.tf_handle, self._dtype, self.name)
    _maybe_set_handle_data(self._dtype, self._handle, value)
    if not context.executing_eagerly():
        # The backward function routes the incoming gradient to the primary
        # handle and None to the tf handle.
        tape.record_operation(
            "ReadEmbeddingVariableOp", [value],
            [self._handle, self.tf_handle],
            lambda x: [x, None])
    return value
def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    with ops.name_scope("Gather" if name is None else name) as name:
        resource_variable_ops.variable_accessed(self)
        # Default values have shape indices.shape + self.shape so that keys
        # absent from the table come back freshly initialized.
        default_shape = array_ops.concat(
            [array_ops.shape(indices), self.shape.as_list()], axis=0)
        default_value = self._initializer(default_shape, dtype=self._dtype)
        gathered = gen_ev_ops.ev_gather(
            self._handle, indices, default_value, name=name)
        return array_ops.identity(gathered)
def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    # Keys must match the embedding table's key dtype exactly.
    if indices.dtype != self._ktype:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "type of indices is not match with EmbeddingVariable key type.")
    with ops.name_scope("Gather" if name is None else name) as name:
        resource_variable_ops.variable_accessed(self)
        # Per-key default shape drops the leading (key) dimension of
        # self.shape and is prefixed by the shape of `indices`.
        default_shape = array_ops.concat(
            [array_ops.shape(indices), self.shape.as_list()[1:]], axis=0)
        default_value = self._initializer(default_shape, dtype=self.dtype)
        gathered = gen_ev_ops.ev_gather(
            self._handle, indices, default_value, name=name)
        return array_ops.identity(gathered)
def embedding_lookup(embedding_variable, values, training=True, dynamic_input=False):
    """This function is a wrapper of SOK's dense forward propagation."""
    layer = embedding_variable.embedding_layer
    resource_variable_ops.variable_accessed(embedding_variable)
    replica_id = get_global_replica_id(_get_comm_tool())
    return kit_lib.plugin_dense_fprop(
        embedding_variable._handle,
        layer.handle,
        values=values,
        global_replica_id=replica_id,
        training=training,
        unique_op_name=embedding_variable.name,
        dynamic_input=dynamic_input,
        dtype=layer.compute_dtype)
def _read_variable_op(self):
    """Read this embedding variable's value, honoring any caching device."""
    variable_accessed(self)

    def read_and_set_handle():
        # Issue the read and attach shape/handle metadata to the result.
        result = kit_lib.read_embedding_variable(self._handle, self.tf_handle,
                                                 self._dtype, self.name)
        _maybe_set_handle_data(self._dtype, self._handle, result)
        return result

    if getattr(self, "_caching_device", None) is not None:
        # Place the read on the caching device, overriding any colocation
        # constraint inherited from the surrounding scope.
        with ops.colocate_with(None, ignore_existing=True):
            with ops.device(self._caching_device):
                result = read_and_set_handle()
    else:
        result = read_and_set_handle()
    if not context.executing_eagerly():
        # Note that if a control flow context is active the input of the read op
        # might not actually be the handle. This line bypasses it.
        # The backward function sends the gradient to the primary handle and
        # None to the tf handle.
        tape.record_operation(
            "ReadEmbeddingVariableOp", [result],
            [self._handle, self.tf_handle],
            backward_function=lambda x: [x, None],
            forward_function=lambda x: [x])
    return result