def dot_general(lhs, rhs, dimension_numbers, precision_config=None,
                preferred_element_type=None, name=None):
    """Emit an XLA dot op for `lhs` and `rhs`, choosing between v1 and v2.

    The v2 op (`gen_xla_ops.xla_dot_v2`) is used when the caller supplies a
    `preferred_element_type` or when `lhs.dtype` differs from `rhs.dtype`;
    otherwise the call is forwarded to `gen_xla_ops.xla_dot`.

    Args:
      lhs: Left operand; must expose a `dtype` attribute. Passed through to
        the generated op unchanged.
      rhs: Right operand; must expose a `dtype` attribute. Passed through to
        the generated op unchanged.
      dimension_numbers: Object exposing `SerializeToString()`; its
        serialized form is passed as the op's `dimension_numbers`
        (presumably a dot-dimension-numbers proto -- confirm with callers).
      precision_config: Optional object exposing `SerializeToString()`. When
        falsy, the empty string is passed as the op's `precision_config`.
      preferred_element_type: Optional element type forwarded to the v2 op.
        When `None` and the v2 op is needed, it defaults to
        `np_utils.result_type(lhs.dtype, rhs.dtype)`.
      name: Optional name forwarded to the generated op.

    Returns:
      Whatever the selected generated op (`xla_dot_v2` or `xla_dot`) returns.
    """
    precision_config_proto = ""
    if precision_config:
        precision_config_proto = precision_config.SerializeToString()

    # Record whether the caller asked for a specific type *before* the
    # default is filled in below. Compare against None rather than relying on
    # truthiness: a non-None value that happens to evaluate as falsy would
    # otherwise be silently dropped whenever the operand dtypes match.
    # NOTE(review): some dtype-like objects may evaluate as falsy -- confirm
    # which types callers pass here.
    needs_v2 = preferred_element_type is not None or lhs.dtype != rhs.dtype

    if preferred_element_type is None:
        preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)

    if needs_v2:
        return gen_xla_ops.xla_dot_v2(
            lhs,
            rhs,
            dimension_numbers=dimension_numbers.SerializeToString(),
            precision_config=precision_config_proto,
            preferred_element_type=preferred_element_type,
            name=name)

    return gen_xla_ops.xla_dot(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        name=name)
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
    """Forward a general dot product of `lhs` and `rhs` to `gen_xla_ops.xla_dot`.

    Args:
      lhs: Left operand, passed through to the op unchanged.
      rhs: Right operand, passed through to the op unchanged.
      dimension_numbers: Object exposing `SerializeToString()`; the
        serialized form is handed to the op as `dimension_numbers`.
      precision_config: Optional object exposing `SerializeToString()`. A
        falsy value results in the empty string being passed instead.
      name: Optional name forwarded to the op.

    Returns:
      The result of `gen_xla_ops.xla_dot`.
    """
    # The op takes serialized attributes, so convert both up front.
    serialized_precision = (
        precision_config.SerializeToString() if precision_config else "")
    serialized_dims = dimension_numbers.SerializeToString()

    return gen_xla_ops.xla_dot(
        lhs,
        rhs,
        dimension_numbers=serialized_dims,
        precision_config=serialized_precision,
        name=name)
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
    """Call `gen_xla_ops.xla_dot` with serialized dot attributes.

    Args:
      lhs: Left operand, handed to the op as-is.
      rhs: Right operand, handed to the op as-is.
      dimension_numbers: Object whose `SerializeToString()` output becomes
        the op's `dimension_numbers` argument.
      precision_config: Optional object whose `SerializeToString()` output
        becomes the op's `precision_config` argument; when it is falsy the
        empty string is used.
      name: Optional op name, forwarded unchanged.

    Returns:
      Whatever `gen_xla_ops.xla_dot` returns.
    """
    if not precision_config:
        # No precision settings supplied: the op receives an empty string.
        encoded_precision = ""
    else:
        encoded_precision = precision_config.SerializeToString()

    return gen_xla_ops.xla_dot(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=encoded_precision,
        name=name,
    )