示例#1
0
    def test_numpy_arguments(self):
        """numpy_function accepts plain Python ints and returns their sum."""
        def add_pair(lhs, rhs):
            return lhs + rhs

        result = script_ops.numpy_function(add_pair, [1, 2], dtypes.int32)
        # Compare against the same value built directly as a TF constant.
        expected = constant_op.constant(3, dtypes.int32)
        self.assertAllEqual(result, expected)
示例#2
0
    def callback(self,
                 op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name=None,
                 graph=None):
        """Record an op event and, in graph mode, optionally instrument outputs.

        Eager events (falsy `graph`) are appended to the `self.eager_*` lists.
        Graph events are appended to the `self.graph_*` lists. When
        `self.instrument_graph_ops` is set, each output of an op whose type is
        not If/StatelessIf/While/Identity/VarHandle is wrapped in a
        `numpy_function` that appends the output's runtime value to
        `self.graph_internal_ndarrays[compat.as_bytes(op_name)]`.

        Args:
            op_type: Type of the op; stored byte-encoded when truthy.
            inputs: Inputs of the op, recorded as-is.
            attrs: Attributes of the op, recorded as-is.
            outputs: Output tensors of the op.
            op_name: Optional op name; stored byte-encoded when truthy.
            graph: Graph the op is created in; falsy means eager execution.

        Returns:
            None for eager events. For graph events: `outputs` itself when
            instrumentation is disabled, otherwise a new list with one
            (possibly instrumented) entry per output, in order.
        """
        # Encode once; both branches record the same normalized values.
        op_type_bytes = compat.as_bytes(op_type) if op_type else op_type
        op_name_bytes = compat.as_bytes(op_name) if op_name else op_name

        is_eager = not graph
        if is_eager:
            self.eager_op_types.append(op_type_bytes)
            self.eager_op_names.append(op_name_bytes)
            self.eager_attrs.append(attrs)
            self.eager_graphs.append(graph)
            self.eager_inputs.append(inputs)
            return None

        self.graph_op_types.append(op_type_bytes)
        self.graph_op_names.append(op_name_bytes)
        self.graph_attrs.append(attrs)
        self.graph_graphs.append(graph)
        self.graph_graph_versions.append(graph.version)
        self.graph_inputs.append(inputs)

        if not self.instrument_graph_ops:
            return outputs

        if op_type_bytes in (_IF_OP, _STATELESS_IF_OP, _WHILE_OP,
                             _IDENTITY_OP, _VAR_HANDLE_OP):
            # TODO(cais): Overriding the output of StatelessIf, If and While ops
            # currently fails with error. Investigate (b/139668453).
            # Avoid instrumenting Identity ops as well, as they are inserted
            # by tf.function/AutoGraph for marshalling outputs.
            return list(outputs)

        def record(ndarray_value):
            # Called by numpy_function with an output's value; the op-name key
            # is encoded here so encoding happens when values are recorded.
            key = compat.as_bytes(op_name)
            self.graph_internal_ndarrays.setdefault(key, []).append(
                ndarray_value)
            return ndarray_value

        # Instrument the graph with numpy_function. `record` holds no
        # per-output state, so a single closure serves every output.
        instrumented_outputs = []
        for output in outputs:
            instrumented_output = script_ops.numpy_function(
                record, [output], output.dtype)
            # Carry the original output's static shape over to the wrapper.
            instrumented_output.set_shape(output.shape)
            instrumented_outputs.append(instrumented_output)
        return instrumented_outputs
示例#3
0
def compress(summary, epsilon):
  """Compress a summary to within `epsilon` accuracy.

  Compression keeps summary sizes small after merging and is also used to
  produce the final target boundaries. New bins are found by interpolating
  cumulative weight percentages from the larger summary; subtracting the
  previous bin's cumulative weight from a bin's cumulative weight gives that
  bin's new weight.

  Args:
      summary: 2-D `np.ndarray` summary to be compressed.
      epsilon: A `'float32'` that determines the approximate desired precision.

  Returns:
      The compressed summary computed by `_compress_summary_numpy`, returned
      through `numpy_function` with dtype `float32`. First column is the
      interpolated partition values, the second is the weights (counts).
  """

  def _compress_with_epsilon(s):
    # `epsilon` is captured from the enclosing call, so only the summary has
    # to be passed through numpy_function.
    return _compress_summary_numpy(s, epsilon)

  # TODO(b/184863356): remove the numpy escape hatch here.
  return script_ops.numpy_function(
      _compress_with_epsilon, [summary], dtypes.float32)
    def _filter_top_k(x):
      # Routing through numpy_function drops the static shape of `x`;
      # presumably intentional so the top-k filter runs without static
      # shape information — confirm against the enclosing test.
      shape_erased = script_ops.numpy_function(_identity, (x,), dtypes.float32)
      return metrics_utils._filter_top_k(x=shape_erased, k=2)
示例#5
0
 def numpy_func_stateful(a, b):
     """Evaluate `plus(a, b)` via numpy_function as a stateful int32 op."""
     operands = [a, b]
     return numpy_function(plus, operands, dtypes.int32, stateful=True)
示例#6
0
 def numpy_func_stateless(a, b):
     """Evaluate `plus(a, b)` via numpy_function as a stateless int32 op."""
     operands = [a, b]
     return numpy_function(plus, operands, dtypes.int32, stateful=False)
 def to_upper(x):
     """Upper-case `x` via numpy_function, decoding it as UTF-8 first."""

     def _decode_and_upper(raw):
         # The value received here is decoded as UTF-8 before upper-casing.
         return raw.decode("utf-8").upper()

     return script_ops.numpy_function(
         func=_decode_and_upper, inp=[x], Tout=dtypes.string)