def map_function(x):
    """Identity map, except that a zero element is routed through a
    non-stateful `py_func` call to `sleep` and statically checked to be a
    scalar via `ensure_shape`."""
    if math_ops.equal(x, 0):
        slept = script_ops.py_func(sleep, [x], x.dtype, stateful=False)
        # Pin the py_func result to a scalar shape.
        return check_ops.ensure_shape(slept, ())
    return x
def embedding_lookup_sparse(embedding_variable, sp_ids, slot_num, training=True):
    """Wrapper around SOK's sparse forward propagation.

    Args:
        embedding_variable: the SOK embedding variable whose layer is queried.
        sp_ids: a `SparseTensor` of lookup ids; its indices must be rank-2.
        slot_num: number of slots, forwarded to the fprop kernel.
        training: whether the op runs in training mode.

    Raises:
        TypeError: if `sp_ids` is not a `SparseTensor`.
    """
    if not isinstance(sp_ids, sparse_tensor.SparseTensor):
        raise TypeError("sp_ids must be SparseTensor")

    ids = sp_ids.values
    # A rank-2 SparseTensor has indices of shape (nnz, 2); assert that
    # statically, then take the first coordinate of each entry as its row.
    checked_indices = check_ops.ensure_shape(sp_ids.indices, shape=(None, 2))
    rows = array_ops.transpose(checked_indices, perm=[1, 0])[0]

    layer = embedding_variable.embedding_layer
    # Record the variable read so TF's variable tracking sees the access.
    resource_variable_ops.variable_accessed(embedding_variable)
    comm_tool = _get_comm_tool()
    return kit_lib.plugin_sparse_fprop(
        embedding_variable._handle,  # pylint: disable=protected-access
        layer.handle,
        ids,
        rows,
        get_global_replica_id(comm_tool),
        slot_num=slot_num,
        training=training,
        unique_op_name=embedding_variable.name,
        dtype=layer.compute_dtype)
def testInvalidEnsureShape(self):
    """A rank-mismatched feed must fail ensure_shape at run time."""
    fed_value = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  # rank 2, not (None, 3, 3)
    with self.session() as sess:
        placeholder = array_ops.placeholder(dtypes.int32)
        with self.test_scope():
            checked = check_ops.ensure_shape(placeholder, (None, 3, 3))
        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    "is not compatible with expected shape"):
            sess.run(checked, {placeholder: fed_value})
def testEnsureShape(self):
    """A compatible feed passes through ensure_shape unchanged."""
    fed_value = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  # shape (3, 3) matches (None, 3)
    with self.session() as sess:
        placeholder = array_ops.placeholder(dtypes.int32)
        with self.test_scope():
            checked = check_ops.ensure_shape(placeholder, (None, 3))
        self.assertAllEqual(fed_value,
                            sess.run(checked, {placeholder: fed_value}))
def _get_vocab_and_ids(self):
    """Export (vocab, ids) from the lookup table, sorted by ascending id."""
    export_fn = getattr(self._vocab_lookup_table, 'export', None)
    if export_fn is None:
        # Fall back to the wrapped table's export.
        export_fn = getattr(self._vocab_lookup_table, '_table').export
    tokens, token_ids = export_fn()  # pylint: disable=protected-access
    # `.export` doesn't set the shapes; pin both outputs to rank 1.
    tokens = check_ops.ensure_shape(tokens, [None])
    token_ids = check_ops.ensure_shape(token_ids, [None])
    permutation = sort_ops.argsort(token_ids)
    return (array_ops.gather(tokens, permutation),
            array_ops.gather(token_ids, permutation))
def call(self, inputs):
    """Densify/pad `inputs` with `self._pad_value` and optionally mask.

    Ragged inputs are padded directly to `self._shape`; sparse inputs are
    densified and dense inputs passed through, both then shape-checked when
    a target shape is configured.

    Raises:
        TypeError: for any input that is not ragged, sparse, or dense.
    """
    if isinstance(inputs, ragged_tensor.RaggedTensor):
        # Ragged tensors are converted to a padded uniform tensor in one step.
        padded = inputs.to_tensor(default_value=self._pad_value,
                                  shape=self._shape)
    else:
        if isinstance(inputs, sparse_tensor.SparseTensor):
            # Missing sparse entries take the pad value.
            padded = sparse_ops.sparse_tensor_to_dense(
                inputs, default_value=self._pad_value)
        elif isinstance(inputs, ops.Tensor):
            padded = inputs
        else:
            raise TypeError('Unexpected tensor type %s' % type(inputs).__name__)
        # Sparse and dense paths share the static shape check.
        if self._shape is not None:
            padded = check_ops.ensure_shape(padded, shape=self._shape)
    if self._mask:
        padded = self.masking_layer(padded)
    return padded