Example #1
0
def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None):
    """Emit an XLA dot op, choosing between `xla_dot` and `xla_dot_v2`.

    The v2 op is used when an output element type is explicitly requested
    or when the two operands have different dtypes; otherwise the v1 op is
    used.

    Args:
        lhs: Left operand; its `.dtype` is read.
        rhs: Right operand; its `.dtype` is read.
        dimension_numbers: Message whose `SerializeToString()` result is
            passed as the `dimension_numbers` attribute.
        precision_config: Optional message. When truthy, its serialized
            bytes are passed as `precision_config`; otherwise "" is passed.
        preferred_element_type: Optional requested output element type.
            When None and the operand dtypes differ, the type is derived via
            `np_utils.result_type(lhs.dtype, rhs.dtype)`.
        name: Optional op name.

    Returns:
        The result of `gen_xla_ops.xla_dot_v2` or `gen_xla_ops.xla_dot`.
    """
    precision_config_proto = ""
    if precision_config:
        precision_config_proto = precision_config.SerializeToString()
    dimension_numbers_proto = dimension_numbers.SerializeToString()

    # Compare against None explicitly rather than by truthiness: some dtype
    # objects are falsy (e.g. a non-structured numpy `np.dtype`, whose len()
    # is 0), and a truthiness test would silently drop an explicitly
    # requested output type whenever the operand dtypes happen to match.
    needs_v2 = (preferred_element_type is not None
                or lhs.dtype != rhs.dtype)

    if not needs_v2:
        return gen_xla_ops.xla_dot(
            lhs,
            rhs,
            dimension_numbers=dimension_numbers_proto,
            precision_config=precision_config_proto,
            name=name)

    if preferred_element_type is None:
        # Mixed operand dtypes with no explicit request: use the promoted type.
        preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers_proto,
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
Example #2
0
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  """Emit an XLA dot op through `gen_xla_ops.xla_dot`.

  Args:
    lhs: Left operand, passed through unchanged.
    rhs: Right operand, passed through unchanged.
    dimension_numbers: Message whose `SerializeToString()` result becomes the
      `dimension_numbers` attribute.
    precision_config: Optional message. If truthy, its serialized bytes become
      the `precision_config` attribute; otherwise the empty string is used.
    name: Optional op name.

  Returns:
    Whatever `gen_xla_ops.xla_dot` returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_dot(
      lhs, rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=serialized_precision,
      name=name)
Example #3
0
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  """Build an XLA dot op by delegating to `gen_xla_ops.xla_dot`.

  Both configuration messages are serialized to bytes before being handed to
  the op. A missing (falsy) `precision_config` is sent as the empty string.

  Args:
    lhs: Left operand.
    rhs: Right operand.
    dimension_numbers: Message serialized via `SerializeToString()`.
    precision_config: Optional message serialized via `SerializeToString()`.
    name: Optional op name.

  Returns:
    The value returned by `gen_xla_ops.xla_dot`.
  """
  if not precision_config:
    precision_bytes = ""
  else:
    precision_bytes = precision_config.SerializeToString()
  return gen_xla_ops.xla_dot(
      lhs, rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_bytes,
      name=name)