def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
  if tpu.under_tpu_inference_context():

    def host_computation():
      return fc._SharedEmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    return tpu.outside_compilation(host_computation)

  if _is_running_on_cpu():
    return fc._SharedEmbeddingColumn._get_dense_tensor(
        self, inputs, weight_collections, trainable)

  # TPU mode
  # Get the embeddings from the LazyBuilder.
  tensor = inputs.get(self.get_feature_key_name())

  # Add to collection for _create_tpu_embedding_variables_and_ops
  _record_variable_scope_and_name(
      self.get_embedding_var_name(),
      'embedding_weights',
      is_shared_embedding=True)
  return tensor
def trace_tpu(self, graph, result_tensor, num_replicas=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops executed on the TPU.
    result_tensor: a result tensor of evaluating the graph.
    num_replicas: number of replicas used on the TPU.

  Returns:
    A tuple (result_tensor_copy, tracing_ops), where:
      result_tensor_copy: an exact copy of result_tensor
      tracing_ops: a list of tracing ops. If this list is non empty, the
        caller of this function should pose control dependencies upon these
        Ops so that they will be executed when the graph is evaluated.
  """

  self._device_type = _DEVICE_TYPE_TPU
  TensorTracer.check_device_type(self._device_type)
  result_tensor_copy = self._add_replica_id_to_graph(
      num_replicas, result_tensor)
  (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
  tracing_ops = []
  checkpoint_operations = self._get_checkpoints(graph)
  for op_id, op in enumerate(operations):
    if checkpoint_operations and op.name not in checkpoint_operations:
      continue
    user_included = self._is_user_included_op(op)
    user_excluded = self._is_user_excluded_op(op)
    if self._skip_op(op_id, op, user_included, user_excluded):
      continue
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      if self._skip_tensor(op_id, out_tensor, user_included, user_excluded):
        continue
      consumers = out_tensor.consumers()
      trace_op = tpu.outside_compilation(
          self._make_tensor_trace_fun(op.name, i), out_tensor)
      if consumers:
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access
      else:
        # if there is no consumer, we will add the control dependence later
        # when we add the control dependency to the output operations.
        tracing_ops.append(trace_op)
  self._post_tracing(succeed, sorted_or_cycle)
  return (result_tensor_copy, tracing_ops)
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
  if tpu.under_tpu_inference_context():

    def host_computation():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    return tpu.outside_compilation(host_computation)

  if _is_running_on_cpu():
    return fc._EmbeddingColumn._get_dense_tensor(
        self, inputs, weight_collections, trainable)

  # TPU mode
  # Get the embeddings from the LazyBuilder.
  tensor = inputs.get(self.get_feature_key_name())

  # Add to collection for _create_tpu_embedding_variables_and_ops
  _record_variable_scope_and_name(self.get_embedding_var_name(),
                                  'embedding_weights')
  return tensor
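# A self-contained sketch (not part of the source) of the three-way dispatch
# that both _get_dense_tensor overrides above follow: under a TPU inference
# context the CPU implementation runs on the host via outside compilation, on
# CPU it is called directly, and in TPU training mode the embedding activation
# is simply fetched from the builder while the variable scope is recorded for
# later variable creation. Every name below (cpu_lookup, run_on_host,
# record_scope, builder) is hypothetical, not a TensorFlow API.
def _dense_tensor_dispatch_sketch(column, builder, on_tpu, in_inference,
                                  run_on_host, record_scope):
  if in_inference:
    # Inference: evaluate the CPU embedding lookup on the host.
    return run_on_host(lambda: column.cpu_lookup(builder))
  if not on_tpu:
    # Plain CPU training/eval: use the CPU implementation directly.
    return column.cpu_lookup(builder)
  # TPU training: the activation was already produced by the TPU embedding
  # enqueue, so just fetch it and record where the weights live.
  record_scope(column.var_name, 'embedding_weights')
  return builder.get(column.key)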
def trace_tpu(self, graph, result_tensor, num_replicas=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops executed on the TPU.
    result_tensor: a result tensor of evaluating the graph.
    num_replicas: number of replicas used on the TPU.

  Returns:
    A tuple (result_tensor_copy, tracing_ops), where:
      result_tensor_copy: an exact copy of result_tensor
      tracing_ops: a list of tracing ops. If this list is non empty, the
        caller of this function should pose control dependencies upon these
        Ops so that they will be executed when the graph is evaluated.
  """

  def _cast_unsupported_dtypes(tensor):
    """Casts tensor to a supported type."""

    if tensor.dtype.__eq__(dtypes.int64):
      # outside-compilation doesn't support int64 input yet.
      return math_ops.cast(tensor, dtypes.int32)
    if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
        dtypes.float16):
      # Since host can't handle bf16, convert tensor to f32.
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  self._device_type = _DEVICE_TYPE_TPU
  TensorTracer.check_device_type(self._device_type)
  result_tensor_copy = self._add_replica_id_to_graph(
      num_replicas, result_tensor)
  (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
  tracing_ops = []
  checkpoint_operations = self._get_checkpoints(graph)
  for op_id, op in enumerate(operations):
    if checkpoint_operations and op.name not in checkpoint_operations:
      continue
    user_included = self._is_user_included_op(op)
    user_excluded = self._is_user_excluded_op(op)
    if self._skip_op(op_id, op, user_included, user_excluded):
      continue
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      if self._skip_tensor(op_id, out_tensor, user_included, user_excluded):
        continue
      # Create the list of consumers before calling
      # _preprocess_traced_tensor. Otherwise, adding the control input below
      # will introduce a cycle in the graph.
      consumers = out_tensor.consumers()
      tensor_name = out_tensor.name
      processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
      processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
      trace_op = tpu.outside_compilation(
          self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
      if consumers:
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access
      else:
        # if there is no consumer, we will add the control dependence later
        # when we add the control dependency to the output operations.
        tracing_ops.append(trace_op)
  self._post_tracing(succeed, sorted_or_cycle)
  return (result_tensor_copy, tracing_ops)
def trace_tpu(self, graph, result_tensor, num_replicas=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops.
    result_tensor: a result tensor of evaluating the graph.
    num_replicas: number of replicas used on the TPU.

  Returns:
    A tuple (result_tensor_copy, tracing_ops), where:
      result_tensor_copy: an exact copy of result_tensor
      tracing_ops: a list of tracing ops. If this list is non empty, the
        caller of this function should pose control dependencies upon these
        Ops so that they will be executed when the graph is evaluated.
  """

  self._device_type = _DEVICE_TYPE_TPU
  TensorTracer.check_device_type(self._device_type)
  result_tensor_copy = self._add_replica_id_to_graph(
      num_replicas, result_tensor)
  self._write_config_section()
  tracing_ops = []
  operations = graph.get_operations()
  self._write_op_list_section(operations)
  # Does the topological sort before adding any nodes to the graph.
  (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
  for op_id, op in enumerate(operations):
    if not self._inside_op_range(op_id):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_OUTSIDE_OP_RANGE)
      continue
    if not TensorTracer.should_trace(self._device_type, op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_SHOULD_NOT_TRACE)
      continue
    if not self._is_selected_op(op.name):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_FILTERED_OUT)
      continue
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      if not out_tensor.get_shape().is_fully_defined():
        self._instrument_records[
            out_tensor.name] = TensorTracer.reason(
                op_id, _RECORD_DYNAMIC_SHAPE)
        continue  # cannot trace tensors with dynamic shape.
      rank = len(out_tensor.shape)
      if rank < 1:
        self._instrument_records[
            out_tensor.name] = TensorTracer.reason(
                op_id, _RECORD_SCALAR)
        continue  # cannot trace scalar.
      self._instrument_records[
          out_tensor.name] = TensorTracer.reason(
              op_id, _RECORD_GET_TRACED)
      consumers = out_tensor.consumers()
      trace_op = tpu.outside_compilation(
          self._make_tensor_trace_fun(op.name, i), out_tensor)
      if consumers:
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access
      else:
        # if there is no consumer, we will add the control dependence later
        # when we add the control dependency to the output operations.
        tracing_ops.append(trace_op)
  self._write_reason_section()
  self._write_graph_section(succeed, sorted_or_cycle)
  return (result_tensor_copy, tracing_ops)
def trace_tpu(self, graph, result_tensor, num_replicas=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops.
    result_tensor: a result tensor of evaluating the graph.
    num_replicas: number of replicas used on the TPU.

  Returns:
    A tuple (result_tensor_copy, tracing_ops), where:
      result_tensor_copy: an exact copy of result_tensor
      tracing_ops: a list of tracing ops. If this list is non empty, the
        caller of this function should pose control dependencies upon these
        Ops so that they will be executed when the graph is evaluated.
  """

  self._device_type = _DEVICE_TYPE_TPU
  TensorTracer.check_device_type(self._device_type)
  result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
                                                     result_tensor)
  self._write_config_section()
  tracing_ops = []
  operations = graph.get_operations()
  self._write_op_list_section(operations)
  # Does the topological sort before adding any nodes to the graph.
  (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
  for op_id, op in enumerate(operations):
    if not self._inside_op_range(op_id):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_OUTSIDE_OP_RANGE)
      continue
    if not TensorTracer.should_trace(self._device_type, op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_SHOULD_NOT_TRACE)
      continue
    if not self._is_selected_op(op.name):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _RECORD_FILTERED_OUT)
      continue
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      if not out_tensor.get_shape().is_fully_defined():
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _RECORD_DYNAMIC_SHAPE)
        continue  # cannot trace tensors with dynamic shape.
      rank = len(out_tensor.shape)
      if rank < 1:
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _RECORD_SCALAR)
        continue  # cannot trace scalar.
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _RECORD_GET_TRACED)
      consumers = out_tensor.consumers()
      trace_op = tpu.outside_compilation(
          self._make_tensor_trace_fun(op.name, i), out_tensor)
      if consumers:
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access
      else:
        # if there is no consumer, we will add the control dependence later
        # when we add the control dependency to the output operations.
        tracing_ops.append(trace_op)
  self._write_reason_section()
  self._write_graph_section(succeed, sorted_or_cycle)
  return (result_tensor_copy, tracing_ops)
def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops executed on the TPU.
    result_tensor: a result tensor of evaluating the graph.
    num_replicas: number of replicas used on the TPU.
    fetches: the list of fetches given to session.run, used to determine the
      ops in the execution path. If None, the whole graph will be traced.

  Returns:
    A tuple (result_tensor_copy, tracing_ops), where:
      result_tensor_copy: an exact copy of result_tensor
      tracing_ops: a list of tracing ops. If this list is non empty, the
        caller of this function should pose control dependencies upon these
        Ops so that they will be executed when the graph is evaluated.
  """

  def _cast_unsupported_dtypes(tensor):
    """Casts tensor to a supported type."""

    if tensor.dtype.__eq__(dtypes.int64):
      # outside-compilation doesn't support int64 input yet.
      return math_ops.cast(tensor, dtypes.int32)
    if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
        dtypes.float16):
      # Since host can't handle bf16, convert tensor to f32.
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  self._device_type = _DEVICE_TYPE_TPU
  TensorTracer.check_device_type(self._device_type)
  result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
                                                     result_tensor)
  (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
  # Filter out the operations that won't be executed.
  # If fetches=None, then ops_in_exec_path = set(operations).
  ops_in_exec_path = self._filter_execution_path_operations(operations,
                                                            fetches)
  tracing_ops = []
  checkpoint_operations = self._get_checkpoints(graph)
  for op_id, op in enumerate(operations):
    if checkpoint_operations and op.name not in checkpoint_operations:
      continue
    user_included = self._is_user_included_op(op)
    user_excluded = self._is_user_excluded_op(op)
    in_exec_path = op in ops_in_exec_path
    if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
      continue
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      if self._skip_tensor(op_id, out_tensor, user_included, user_excluded):
        continue
      # Create the list of consumers before calling
      # _preprocess_traced_tensor. Otherwise, adding the control input below
      # will introduce a cycle in the graph.
      consumers = out_tensor.consumers()
      tensor_name = out_tensor.name
      processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
      processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
      trace_op = tpu.outside_compilation(
          self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
      if consumers:
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access
      else:
        # if there is no consumer, we will add the control dependence later
        # when we add the control dependency to the output operations.
        tracing_ops.append(trace_op)
  self._post_tracing(succeed, sorted_or_cycle)
  return (result_tensor_copy, tracing_ops)
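# A minimal usage sketch (not from the source) showing how a caller might
# consume trace_tpu's return value: the docstring above requires the caller to
# add control dependencies on `tracing_ops` so that traces of consumer-less
# tensors still execute when the graph is evaluated. The `tracer`, `graph`,
# `loss` and `num_replicas` arguments are assumed to exist; all names here are
# illustrative only.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


def _traced_loss(tracer, graph, loss, num_replicas, fetches=None):
  result_copy, tracing_ops = tracer.trace_tpu(
      graph, loss, num_replicas=num_replicas, fetches=fetches)
  if tracing_ops:
    # Force the dangling trace ops to run whenever the result is evaluated.
    with ops.control_dependencies(tracing_ops):
      result_copy = array_ops.identity(result_copy)
  return result_copy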