def make_function_signature(
    function_args,
    signature_context: SignatureContext,
    encode_variables_by_resource_id,
    use_full_trace_type,
) -> trace.TraceType:
  """Computes the trace type specification for a function's arguments.

  Args:
    function_args: Tuple/List/Dict structure holding the function arguments.
    signature_context: SignatureContext shared across the protocol calls made
      while encoding; its `include_tensor_ranks_only` attribute is forwarded
      to the encoder.
    encode_variables_by_resource_id: Whether Variables should be considered
      by resource id.
    use_full_trace_type: Whether to use the TraceType protocol wherever
      possible.

  Returns:
    A TraceType object representing all the given inputs. When
    `use_full_trace_type` is false, the raw encoding is wrapped in
    `GenericType`.
  """
  ranks_only = signature_context.include_tensor_ranks_only
  try:
    encoded = pywrap_tfe.TFE_Py_EncodeArg(
        function_args,
        signature_context,
        ranks_only,
        encode_variables_by_resource_id,
        use_full_trace_type,
    )
    # TODO(b/201533914): Drop when use_full_trace_type flag is removed.
    return encoded if use_full_trace_type else GenericType(encoded)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    # Re-raise as the exception derived from the status, hiding the internal one.
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
def get_arg_spec(inputs, include_tensor_ranks_only,
                 encode_variables_by_resource_id, use_full_trace_type):
  """Computes the trace type specification for a function's arguments.

  A fresh SignatureContext is created for this call from
  `include_tensor_ranks_only`.

  Args:
    inputs: Tuple/List/Dict structure holding the function arguments.
    include_tensor_ranks_only: Whether Tensors should be considered by rank.
    encode_variables_by_resource_id: Whether Variables should be considered
      by resource id.
    use_full_trace_type: Whether to use the TraceType protocol wherever
      possible.

  Returns:
    A TraceType object representing the function arguments. When
    `use_full_trace_type` is false, the raw encoding is wrapped in
    `GenericType`.
  """
  context = SignatureContext(include_tensor_ranks_only)
  try:
    arg_encoding = pywrap_tfe.TFE_Py_EncodeArg(
        inputs, context, include_tensor_ranks_only,
        encode_variables_by_resource_id, use_full_trace_type)
    if not use_full_trace_type:
      # TODO(b/201533914): Drop when use_full_trace_type flag is removed.
      return GenericType(arg_encoding)
    return arg_encoding
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
def get_arg_spec(inputs, include_tensor_ranks_only,
                 encode_variables_by_resource_id):
  """Returns a hashable specification of a function's arguments.

  Args:
    inputs: Tuple/List/Dict structure containing the function arguments.
    include_tensor_ranks_only: If Tensors should be considered by rank.
    encode_variables_by_resource_id: If Variables should be considered by
      resource id.

  Returns:
    A hashable object representing the function arguments, produced by
    passing the encoder's output through `_make_input_signature_hashable`.

  Raises:
    The exception returned by `core._status_to_exception` when the encoder
    signals a non-OK status (instead of the internal
    `core._NotOkStatusException`).
  """
  try:
    encoding = pywrap_tfe.TFE_Py_EncodeArg(inputs, include_tensor_ranks_only,
                                           encode_variables_by_resource_id)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    # Surface the status-derived exception rather than the protected internal
    # one, and drop the internal traceback context.
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
  return _make_input_signature_hashable(encoding)
def make_function_signature(
    function_args, signature_context: SignatureContext) -> trace.TraceType:
  """Computes the trace type specification for a function's arguments.

  Args:
    function_args: Tuple/List/Dict structure holding the function arguments.
    signature_context: SignatureContext shared across the protocol calls made
      while encoding.

  Returns:
    A TraceType object representing all the given inputs, as produced by
    `pywrap_tfe.TFE_Py_EncodeArg`.
  """
  try:
    signature = pywrap_tfe.TFE_Py_EncodeArg(function_args, signature_context)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
  return signature
def get_arg_spec(inputs, include_tensor_ranks_only,
                 encode_variables_by_resource_id, use_full_trace_type):
  """Returns the trace type specification of a function's arguments.

  Args:
    inputs: Tuple/List/Dict structure containing the function arguments.
    include_tensor_ranks_only: If Tensors should be considered by rank.
    encode_variables_by_resource_id: If Variables should be considered by
      resource id.
    use_full_trace_type: Uses the TraceType protocol wherever possible; it is
      forwarded to the encoder.

  Returns:
    A TraceType object representing the function arguments. The encoder's
    output is always wrapped in `GenericType`, regardless of
    `use_full_trace_type`.

  Raises:
    The exception returned by `core._status_to_exception` when the encoder
    signals a non-OK status (instead of the internal
    `core._NotOkStatusException`).
  """
  # TODO(b/201533914): Drop GenericType once TFE_Py_EncodeArg returns TraceType.
  signature_context = SignatureContext()
  try:
    encoding = pywrap_tfe.TFE_Py_EncodeArg(inputs, signature_context,
                                           include_tensor_ranks_only,
                                           encode_variables_by_resource_id,
                                           use_full_trace_type)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    # Surface the status-derived exception rather than the protected internal
    # one, and drop the internal traceback context.
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
  return GenericType(encoding)