def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Return the dense tensor for this column, picking the path by context.

    Three cases are handled, in this order:

    1. Under a TPU inference context, the base
       `fc._SharedEmbeddingColumn._get_dense_tensor` is wrapped in
       `tpu.outside_compilation`.
    2. When running on CPU, the base implementation is called directly.
    3. Otherwise (TPU mode), the value stored in `inputs` under this
       column's feature key is returned, and the embedding variable name is
       recorded as a shared embedding.

    Args:
      inputs: Object exposing `get(key)`; the original comment refers to it
        as the LazyBuilder. Forwarded unchanged to the base implementation
        in the non-TPU paths.
      weight_collections: Forwarded to the base implementation.
      trainable: Forwarded to the base implementation.

    Returns:
      Whatever the selected path produces: the result of
      `tpu.outside_compilation`, the base implementation's result, or
      `inputs.get(self.get_feature_key_name())`.
    """

    def _base_dense_tensor():
        # Single place for the delegation shared by the inference and CPU
        # paths; arguments are passed positionally, as before.
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

    if tpu.under_tpu_inference_context():
        return tpu.outside_compilation(_base_dense_tensor)

    if _is_running_on_cpu():
        return _base_dense_tensor()

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    feature_key = self.get_feature_key_name()
    embedding_tensor = inputs.get(feature_key)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    embedding_var_name = self.get_embedding_var_name()
    _record_variable_scope_and_name(
        embedding_var_name, 'embedding_weights', is_shared_embedding=True)
    return embedding_tensor
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Return the dense tensor for this column, picking the path by context.

    Three cases are handled, in this order:

    1. Under a TPU inference context, the base
       `fc._EmbeddingColumn._get_dense_tensor` is wrapped in
       `tpu.outside_compilation`.
    2. When running on CPU, the base implementation is called directly.
    3. Otherwise (TPU mode), the value stored in `inputs` under this
       column's feature key is returned, and the embedding variable name is
       recorded via `_record_variable_scope_and_name`.

    Args:
      inputs: Object exposing `get(key)`; the original comment refers to it
        as the LazyBuilder. Forwarded unchanged to the base implementation
        in the non-TPU paths.
      weight_collections: Forwarded to the base implementation.
      trainable: Forwarded to the base implementation.

    Returns:
      Whatever the selected path produces: the result of
      `tpu.outside_compilation`, the base implementation's result, or
      `inputs.get(self.get_feature_key_name())`.
    """

    def _delegate_to_base():
        # Single place for the delegation shared by the inference and CPU
        # paths; arguments are passed positionally, as before.
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

    if tpu.under_tpu_inference_context():
        return tpu.outside_compilation(_delegate_to_base)

    if _is_running_on_cpu():
        return _delegate_to_base()

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    lookup_key = self.get_feature_key_name()
    activations = inputs.get(lookup_key)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    var_name = self.get_embedding_var_name()
    _record_variable_scope_and_name(var_name, 'embedding_weights')
    return activations
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Return the dense tensor for this column, picking the path by context.

    Three cases are handled, in this order:

    1. Under a TPU inference context, the base
       `fc._SharedEmbeddingColumn._get_dense_tensor` is called inside the
       `_outside_all_rewrites()` context manager.
    2. When running on CPU, the base implementation is called directly.
    3. Otherwise (TPU mode), the value stored in `inputs` under this
       column's feature key is returned, and the embedding variable name is
       recorded as a shared embedding.

    Args:
      inputs: Object exposing `get(key)`; the original comment refers to it
        as the LazyBuilder. Forwarded unchanged to the base implementation
        in the non-TPU paths.
      weight_collections: Forwarded to the base implementation.
      trainable: Forwarded to the base implementation.

    Returns:
      The base implementation's result in the inference and CPU paths, or
      `inputs.get(self.get_feature_key_name())` in TPU mode.
    """

    def _base_dense_tensor():
        # Single place for the delegation shared by the inference and CPU
        # paths; arguments are passed positionally, as before.
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

    if tpu.under_tpu_inference_context():
        # TODO(shizhiw, b/112012627, b/112336539): Replace
        # _outside_all_rewrites() with outside compilation.
        with _outside_all_rewrites():
            # Return from inside the `with` so the context manager's exit
            # handling applies exactly as it did before.
            return _base_dense_tensor()

    if _is_running_on_cpu():
        return _base_dense_tensor()

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    feature_key = self.get_feature_key_name()
    embedding_tensor = inputs.get(feature_key)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    embedding_var_name = self.get_embedding_var_name()
    _record_variable_scope_and_name(
        embedding_var_name, 'embedding_weights', is_shared_embedding=True)
    return embedding_tensor