Example No. 1
    def _get_dense_tensor(self,
                          inputs,
                          weight_collections=None,
                          trainable=None):
        if tpu.under_tpu_inference_context():

            def host_computation():
                return fc._SharedEmbeddingColumn._get_dense_tensor(
                    self, inputs, weight_collections, trainable)

            return tpu.outside_compilation(host_computation)

        if _is_running_on_cpu():
            return fc._SharedEmbeddingColumn._get_dense_tensor(
                self, inputs, weight_collections, trainable)

        # TPU mode
        # Get the embeddings from the LazyBuilder.
        tensor = inputs.get(self.get_feature_key_name())

        # Add to collection for _create_tpu_embedding_variables_and_ops
        _record_variable_scope_and_name(self.get_embedding_var_name(),
                                        'embedding_weights',
                                        is_shared_embedding=True)
        return tensor
Example No. 2
    def trace_tpu(self, graph, result_tensor, num_replicas=None):
        """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      result_tensor: a result tensor of evaluating the graph.
      num_replicas: number of replicas used on the TPU.

    Returns:
      A tuple (result_tensor_copy, tracing_ops), where:
        result_tensor_copy: an exact copy of result_tensor
        tracing_ops: a list of tracing ops. If this list
                     is non-empty, the caller of this function
                     should add control dependencies on these
                     Ops so that they will be executed when the
                     graph is evaluated.
    """

        self._device_type = _DEVICE_TYPE_TPU
        TensorTracer.check_device_type(self._device_type)
        result_tensor_copy = self._add_replica_id_to_graph(
            num_replicas, result_tensor)
        (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
        tracing_ops = []
        checkpoint_operations = self._get_checkpoints(graph)

        for op_id, op in enumerate(operations):
            if checkpoint_operations and op.name not in checkpoint_operations:
                continue
            user_included = self._is_user_included_op(op)
            user_excluded = self._is_user_excluded_op(op)
            if self._skip_op(op_id, op, user_included, user_excluded):
                continue
            for i in range(len(op.outputs)):
                out_tensor = op.outputs[i]
                if self._skip_tensor(op_id, out_tensor, user_included,
                                     user_excluded):
                    continue
                consumers = out_tensor.consumers()
                trace_op = tpu.outside_compilation(
                    self._make_tensor_trace_fun(op.name, i), out_tensor)
                if consumers:
                    for consumer_op in consumers:
                        # pylint: disable=protected-access
                        consumer_op._add_control_input(trace_op)
                        # pylint: enable=protected-access
                else:
                    # If there is no consumer, we will add the control dependency later,
                    # when control dependencies are added to the output operations.
                    tracing_ops.append(trace_op)
        self._post_tracing(succeed, sorted_or_cycle)
        return (result_tensor_copy, tracing_ops)
Example No. 3
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(self.get_embedding_var_name(),
                                    'embedding_weights')

    return tensor
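
Examples 1 and 3 share the same three-way dispatch: under a TPU inference context the CPU embedding lookup is wrapped in tpu.outside_compilation so it runs on the host; when running on CPU the lookup is called directly; otherwise the TPU path just reads the embeddings that the TPU embedding engine has already produced. The snippet below is a minimal, hypothetical sketch of that control flow only; the parameters under_inference_context, running_on_cpu, cpu_lookup and outside_compilation stand in for tpu.under_tpu_inference_context(), _is_running_on_cpu(), the fc._EmbeddingColumn lookup and tpu.outside_compilation, and are not part of the original code.

def get_dense_tensor(inputs, feature_key, under_inference_context,
                     running_on_cpu, cpu_lookup, outside_compilation):
  """Hypothetical sketch of the dispatch used in examples 1 and 3."""
  if under_inference_context:
    # TPU inference: run the CPU embedding lookup on the host via
    # outside compilation.
    return outside_compilation(cpu_lookup)
  if running_on_cpu:
    # Plain CPU execution: call the CPU lookup directly.
    return cpu_lookup()
  # TPU training: the embeddings were already produced by the TPU embedding
  # engine, so simply read them out of the feature dictionary.
  return inputs[feature_key]
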
Example No. 4
    def trace_tpu(self, graph, result_tensor, num_replicas=None):
        """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      result_tensor: a result tensor of evaluating the graph.
      num_replicas: number of replicas used on the TPU.

    Returns:
      A tuple (result_tensor_copy, tracing_ops), where:
        result_tensor_copy: an exact copy of result_tensor
        tracing_ops: a list of tracing ops. If this list
                     is non-empty, the caller of this function
                     should add control dependencies on these
                     Ops so that they will be executed when the
                     graph is evaluated.
    """
        def _cast_unsupported_dtypes(tensor):
            """Casts tensor to a supported type."""

            if tensor.dtype.__eq__(dtypes.int64):
                # outside-compilation doesn't support int64 input yet.
                return math_ops.cast(tensor, dtypes.int32)
            if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
                    dtypes.float16):
                # Since host can't handle bf16, convert tensor to f32.
                return math_ops.cast(tensor, dtypes.float32)
            return tensor

        self._device_type = _DEVICE_TYPE_TPU
        TensorTracer.check_device_type(self._device_type)
        result_tensor_copy = self._add_replica_id_to_graph(
            num_replicas, result_tensor)
        (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
        tracing_ops = []
        checkpoint_operations = self._get_checkpoints(graph)

        for op_id, op in enumerate(operations):
            if checkpoint_operations and op.name not in checkpoint_operations:
                continue
            user_included = self._is_user_included_op(op)
            user_excluded = self._is_user_excluded_op(op)
            if self._skip_op(op_id, op, user_included, user_excluded):
                continue
            for i in range(len(op.outputs)):
                out_tensor = op.outputs[i]
                if self._skip_tensor(op_id, out_tensor, user_included,
                                     user_excluded):
                    continue
                # Create the list of consumers before calling _preprocess_traced_tensor.
                # Otherwise, adding the control input below will introduce a
                # cycle in the graph.
                consumers = out_tensor.consumers()
                tensor_name = out_tensor.name
                processed_out_tensor = self._preprocess_traced_tensor(
                    out_tensor)
                processed_out_tensor = _cast_unsupported_dtypes(
                    processed_out_tensor)
                trace_op = tpu.outside_compilation(
                    self._make_tensor_trace_fun(tensor_name),
                    processed_out_tensor)
                if consumers:
                    for consumer_op in consumers:
                        # pylint: disable=protected-access
                        consumer_op._add_control_input(trace_op)
                        # pylint: enable=protected-access
                else:
                    # If there is no consumer, we will add the control dependency later,
                    # when control dependencies are added to the output operations.
                    tracing_ops.append(trace_op)
        self._post_tracing(succeed, sorted_or_cycle)
        return (result_tensor_copy, tracing_ops)
Example No. 5
    def trace_tpu(self, graph, result_tensor, num_replicas=None):
        """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops.
      result_tensor: a result tensor of evaluating the graph.
      num_replicas: number of replicas used on the TPU.

    Returns:
      A tuple (result_tensor_copy, tracing_ops), where:
        result_tensor_copy: an exact copy of result_tensor
        tracing_ops: a list of tracing ops. If this list
                     is non-empty, the caller of this function
                     should add control dependencies on these
                     Ops so that they will be executed when the
                     graph is evaluated.
    """

        self._device_type = _DEVICE_TYPE_TPU
        TensorTracer.check_device_type(self._device_type)
        result_tensor_copy = self._add_replica_id_to_graph(
            num_replicas, result_tensor)
        self._write_config_section()
        tracing_ops = []
        operations = graph.get_operations()
        self._write_op_list_section(operations)
        # Does the topological sort before adding any nodes to the graph.
        (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
        for op_id, op in enumerate(operations):
            if not self._inside_op_range(op_id):
                self._instrument_records[op.name] = TensorTracer.reason(
                    op_id, _RECORD_OUTSIDE_OP_RANGE)
                continue
            if not TensorTracer.should_trace(self._device_type, op):
                self._instrument_records[op.name] = TensorTracer.reason(
                    op_id, _RECORD_SHOULD_NOT_TRACE)
                continue
            if not self._is_selected_op(op.name):
                self._instrument_records[op.name] = TensorTracer.reason(
                    op_id, _RECORD_FILTERED_OUT)
                continue
            for i in range(len(op.outputs)):
                out_tensor = op.outputs[i]
                if not out_tensor.get_shape().is_fully_defined():
                    self._instrument_records[
                        out_tensor.name] = TensorTracer.reason(
                            op_id, _RECORD_DYNAMIC_SHAPE)
                    continue  # cannot trace tensors with dynamic shape.
                rank = len(out_tensor.shape)
                if rank < 1:
                    self._instrument_records[
                        out_tensor.name] = TensorTracer.reason(
                            op_id, _RECORD_SCALAR)
                    continue  # cannot trace scalar.
                self._instrument_records[
                    out_tensor.name] = TensorTracer.reason(
                        op_id, _RECORD_GET_TRACED)
                consumers = out_tensor.consumers()
                trace_op = tpu.outside_compilation(
                    self._make_tensor_trace_fun(op.name, i), out_tensor)
                if consumers:
                    for consumer_op in consumers:
                        # pylint: disable=protected-access
                        consumer_op._add_control_input(trace_op)
                        # pylint: enable=protected-access
                else:
                    # If there is no consumer, we will add the control dependency later,
                    # when control dependencies are added to the output operations.
                    tracing_ops.append(trace_op)

        self._write_reason_section()
        self._write_graph_section(succeed, sorted_or_cycle)

        return (result_tensor_copy, tracing_ops)
Example No. 6
  def trace_tpu(self, graph, result_tensor, num_replicas=None):
    """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops.
      result_tensor: a result tensor of evaluating the graph.
      num_replicas: number of replicas used on the TPU.

    Returns:
      A tuple (result_tensor_copy, tracing_ops), where:
        result_tensor_copy: an exact copy of result_tensor
        tracing_ops: a list of tracing ops. If this list
                     is non-empty, the caller of this function
                     should add control dependencies on these
                     Ops so that they will be executed when the
                     graph is evaluated.
    """

    self._device_type = _DEVICE_TYPE_TPU
    TensorTracer.check_device_type(self._device_type)
    result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
                                                       result_tensor)
    self._write_config_section()
    tracing_ops = []
    operations = graph.get_operations()
    self._write_op_list_section(operations)
    # Does the topological sort before adding any nodes to the graph.
    (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
    for op_id, op in enumerate(operations):
      if not self._inside_op_range(op_id):
        self._instrument_records[op.name] = TensorTracer.reason(
            op_id, _RECORD_OUTSIDE_OP_RANGE)
        continue
      if not TensorTracer.should_trace(self._device_type, op):
        self._instrument_records[op.name] = TensorTracer.reason(
            op_id, _RECORD_SHOULD_NOT_TRACE)
        continue
      if not self._is_selected_op(op.name):
        self._instrument_records[op.name] = TensorTracer.reason(
            op_id, _RECORD_FILTERED_OUT)
        continue
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        if not out_tensor.get_shape().is_fully_defined():
          self._instrument_records[out_tensor.name] = TensorTracer.reason(
              op_id, _RECORD_DYNAMIC_SHAPE)
          continue  # cannot trace tensors with dynamic shape.
        rank = len(out_tensor.shape)
        if rank < 1:
          self._instrument_records[out_tensor.name] = TensorTracer.reason(
              op_id, _RECORD_SCALAR)
          continue  # cannot trace scalar.
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _RECORD_GET_TRACED)
        consumers = out_tensor.consumers()
        trace_op = tpu.outside_compilation(
            self._make_tensor_trace_fun(op.name, i), out_tensor)
        if consumers:
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access
        else:
          # If there is no consumer, we will add the control dependency later,
          # when control dependencies are added to the output operations.
          tracing_ops.append(trace_op)

    self._write_reason_section()
    self._write_graph_section(succeed, sorted_or_cycle)

    return (result_tensor_copy, tracing_ops)
Example No. 7
  def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None):
    """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      result_tensor: a result tensor of evaluating the graph.
      num_replicas: number of replicas used on the TPU.
      fetches: the list of fetches given to session.run, used to determine
        the ops in the execution path. If None, the whole graph will be traced.

    Returns:
      A tuple (result_tensor_copy, tracing_ops), where:
        result_tensor_copy: an exact copy of result_tensor
        tracing_ops: a list of tracing ops. If this list
                     is non-empty, the caller of this function
                     should add control dependencies on these
                     Ops so that they will be executed when the
                     graph is evaluated.
    """

    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype.__eq__(dtypes.int64):
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
          dtypes.float16):
        # Since host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    self._device_type = _DEVICE_TYPE_TPU
    TensorTracer.check_device_type(self._device_type)
    result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
                                                       result_tensor)
    (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
    # Filter out the operations that won't be executed.
    # If fetches is None, then ops_in_exec_path == set(operations).
    ops_in_exec_path = self._filter_execution_path_operations(operations,
                                                              fetches)
    tracing_ops = []
    checkpoint_operations = self._get_checkpoints(graph)

    for op_id, op in enumerate(operations):
      if checkpoint_operations and op.name not in checkpoint_operations:
        continue
      user_included = self._is_user_included_op(op)
      user_excluded = self._is_user_excluded_op(op)
      in_exec_path = op in ops_in_exec_path
      if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
        continue
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        if self._skip_tensor(op_id, out_tensor, user_included,
                             user_excluded):
          continue
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding the control input below will introduce a cycle in
        # the graph.
        consumers = out_tensor.consumers()
        tensor_name = out_tensor.name
        processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
        processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
        trace_op = tpu.outside_compilation(
            self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
        if consumers:
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access
        else:
          # If there is no consumer, we will add the control dependency later,
          # when control dependencies are added to the output operations.
          tracing_ops.append(trace_op)
    self._post_tracing(succeed, sorted_or_cycle)
    return (result_tensor_copy, tracing_ops)
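
Every example on this page wraps a host-side function with tpu.outside_compilation so that, inside a TPU-compiled computation, the wrapped ops are placed on the CPU host. The snippet below is a minimal, hypothetical illustration of that call pattern using the public tf.compat.v1.tpu.outside_compilation API; it is not taken from the examples above. Outside of a TPUReplicateContext the wrapped function is simply executed in place, so the sketch also runs on a CPU-only setup.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def _host_side_log(tensor):
  # Host computation: print the tensor values from the CPU host and return
  # the tensor unchanged (tf.Print is a pass-through identity).
  return tf.Print(tensor, [tensor], message='traced value: ')


def model_fn(x):
  y = x * 2.0
  # Inside a TPU-compiled computation this moves _host_side_log onto the
  # host; outside a TPUReplicateContext it simply calls the function.
  y = tf.tpu.outside_compilation(_host_side_log, y)
  return y + 1.0


with tf.Graph().as_default():
  out = model_fn(tf.constant([1.0, 2.0]))
  with tf.Session() as sess:
    print(sess.run(out))  # [3. 5.]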