Exemple #1
0
 def execute(self, sql, bind_variables, **kwargs):
   """Run sql through the connection and record the results on the cursor.

   Statements accepted by _handle_transaction_sql are consumed there and
   yield no row count. Write queries are rejected on a non-writable cursor
   and may not be combined with an entity_keyspace_id_map.

   Args:
     sql: statement text.
     bind_variables: bind variables forwarded to the connection.
     **kwargs: 'entity_keyspace_id_map' and 'entity_column_name' are removed
       and passed explicitly; all remaining keys go to connection._execute.

   Returns:
     self.rowcount after the call, or None when sql was handled as
     transaction-control SQL.

   Raises:
     dbexceptions.DatabaseError: DML on a non-writable cursor, or a write
       query given an entity_keyspace_id_map.
   """
   self._clear_list_state()
   self._clear_batch_state()
   if self._handle_transaction_sql(sql):
     return None

   entity_map = kwargs.pop('entity_keyspace_id_map', None)
   entity_column = kwargs.pop('entity_column_name', None)

   # NOTE: This check may also be done at higher layers but adding it
   # here for completion.
   if write_sql_pattern.match(sql):
     if not self.is_writable():
       raise dbexceptions.DatabaseError('DML on a non-writable cursor', sql)
     if entity_map:
       raise dbexceptions.DatabaseError(
           'entity_keyspace_id_map is not allowed for write queries')

   run = self.connection._execute  # pylint: disable=protected-access
   outcome = run(
       sql,
       bind_variables,
       self.keyspace,
       self.tablet_type,
       shards=self.shards,
       keyspace_ids=self.keyspace_ids,
       keyranges=self.keyranges,
       entity_keyspace_id_map=entity_map,
       entity_column_name=entity_column,
       not_in_transaction=not self.is_writable(),
       effective_caller_id=self.effective_caller_id,
       **kwargs)
   self.results, self.rowcount, self.lastrowid, self.description = outcome
   return self.rowcount
def _convert_exception(exc, *args, **kwargs):
    """Convert a raw protocol exception into an api interface exception.

    Every converted exception is passed to vtgate_utils.log_exception before
    it is returned, so the logging side effect applies on all paths,
    including throttled/transient, query-not-served and programming errors.

    Args:
      exc: raw protocol exception. Its args are extended in place with *args
        and the stringified **kwargs.
      *args: additional args from the raising site.
      **kwargs: additional keyword args from the raising site. They are
        converted into a single string and added as an extra arg to the
        exception; 'keyspace' and 'tablet_type' are also passed to the logger.

    Returns:
      Api interface exception (a dbexceptions instance) built from the new
      args, or exc itself when its type is not recognized.
    """
    kwargs_as_str = vtgate_utils.convert_exception_kwargs(kwargs)
    exc.args += args
    if kwargs_as_str:
        exc.args += (kwargs_as_str,)
    # Lead with the original class name so the converted error still says
    # which protocol error it came from.
    new_args = (type(exc).__name__,) + exc.args
    if isinstance(exc, vtgate_utils.VitessError):
        new_exc = exc.convert_to_dbexception(new_args)
    elif isinstance(exc, grpc.RpcError):
        # Most RpcErrors should also implement Call so we can get details.
        if isinstance(exc, grpc.Call):
            code = exc.code()
            details = exc.details()
            if code == grpc.StatusCode.DEADLINE_EXCEEDED:
                new_exc = dbexceptions.TimeoutError(new_args)
            elif code == grpc.StatusCode.UNAVAILABLE:
                # Guard against a missing details string so the regex search
                # cannot raise TypeError.
                if vtgate_utils.throttler_err_re.search(details or ''):
                    new_exc = dbexceptions.ThrottledError(new_args)
                else:
                    new_exc = dbexceptions.TransientError(new_args)
            elif code == grpc.StatusCode.ALREADY_EXISTS:
                new_exc = _prune_integrity_error(details, new_args)
            elif code == grpc.StatusCode.FAILED_PRECONDITION:
                new_exc = dbexceptions.QueryNotServed(details, new_args)
            elif code == grpc.StatusCode.INVALID_ARGUMENT:
                new_exc = dbexceptions.ProgrammingError(details, new_args)
            else:
                # Other RPC error that we don't specifically handle.
                new_exc = dbexceptions.DatabaseError(new_args +
                                                     (code, details))
        else:
            # RPC error that doesn't provide code and details.
            # Don't let gRPC-specific errors leak beyond this package.
            new_exc = dbexceptions.DatabaseError(new_args + (exc,))
    else:
        new_exc = exc
    vtgate_utils.log_exception(new_exc,
                               keyspace=kwargs.get('keyspace'),
                               tablet_type=kwargs.get('tablet_type'))
    return new_exc
Exemple #3
0
def get_keyrange_from_shard_name(keyspace, shard_name):
  """Return the key range covered by shard_name within keyspace.

  For an unsharded keyspace only keyrange_constants.SHARD_ZERO is accepted
  and maps to keyrange_constants.NON_PARTIAL_KEYRANGE. For a sharded keyspace,
  shard_name must have the form '<start>-<end>' with hex-encoded (possibly
  empty) bounds, and a keyrange.KeyRange over the decoded bounds is returned.

  Raises:
    dbexceptions.DatabaseError: if shard_name is not valid for keyspace.
  """
  import binascii

  # db_type is immaterial here.
  if not is_sharded_keyspace(keyspace, 'replica'):
    if shard_name == keyrange_constants.SHARD_ZERO:
      return keyrange_constants.NON_PARTIAL_KEYRANGE
    # Interpolate the values: passing them as extra exception args leaves
    # the '%s' placeholders unformatted in the message.
    raise dbexceptions.DatabaseError(
        'Invalid shard_name %s for keyspace %s' % (shard_name, keyspace))

  kr_parts = shard_name.split('-')
  if len(kr_parts) != 2:
    raise dbexceptions.DatabaseError(
        'Invalid shard_name %s for keyspace %s' % (shard_name, keyspace))
  # binascii.unhexlify matches str.decode('hex') on Python 2 and also works
  # on Python 3, where str has no 'hex' codec.
  start, end = [binascii.unhexlify(part) for part in kr_parts]
  return keyrange.KeyRange((start, end))
Exemple #4
0
def handle_app_error(exc_args):
    """Map an application error's args to the matching dbexceptions class.

    Args:
      exc_args: tuple of exception args; exc_args[0] is the error message,
        which is lower-cased before inspection.

    Returns:
      A dbexceptions instance (returned, not raised): RetryError, FatalError
      or TxPoolFull for the matching message prefix; IntegrityError with a
      pruned message for MySQL errno 1062; RetryError for errno 1290 with a
      'read-only' message; otherwise DatabaseError.
    """
    msg = str(exc_args[0]).lower()

    # Operational Error
    if msg.startswith('retry'):
        return dbexceptions.RetryError(exc_args)

    if msg.startswith('fatal'):
        return dbexceptions.FatalError(exc_args)

    if msg.startswith('tx_pool_full'):
        return dbexceptions.TxPoolFull(exc_args)

    # Integrity and Database Error
    match = _errno_pattern.search(msg)
    if match:
        mysql_errno = int(match.group(1))
        if mysql_errno == 1062:
            # Prune the error message to truncate after the mysql errno, since
            # the error message may contain the query string with bind
            # variables. Cutting at match.end() is exact. Searching for the
            # trailing text with find() gives an empty prefix when that text
            # is empty, and a too-short prefix when it also appears earlier.
            pruned_msg = msg[:match.end()]
            new_args = (pruned_msg,) + tuple(exc_args[1:])
            return dbexceptions.IntegrityError(new_args)
        # TODO(sougou/liguo): remove this case once servers are deployed
        if mysql_errno == 1290 and 'read-only' in msg:
            return dbexceptions.RetryError(exc_args)

    return dbexceptions.DatabaseError(exc_args)
Exemple #5
0
def convert_exception(exc, *args):
    """Translate a gorpc client exception into a dbexceptions error.

    Args:
      exc: exception raised by the gorpc client.
      *args: extra values appended to exc.args in the converted exception.

    Returns:
      A dbexceptions instance (returned, not raised), or exc itself when its
      type is not one of the handled gorpc errors.
    """
    new_args = exc.args + args
    if isinstance(exc, gorpc.TimeoutError):
        return dbexceptions.TimeoutError(new_args)
    if isinstance(exc, gorpc.AppError):
        # Read the message from exc.args. Indexing the exception itself
        # (exc[0]) only works on Python 2, and an AppError without args
        # should still convert rather than raise IndexError here.
        msg = str(exc.args[0]).lower() if exc.args else ''
        if msg.startswith('retry'):
            return dbexceptions.RetryError(new_args)
        if msg.startswith('fatal'):
            return dbexceptions.FatalError(new_args)
        if msg.startswith('tx_pool_full'):
            return dbexceptions.TxPoolFull(new_args)
        match = _errno_pattern.search(msg)
        if match:
            mysql_errno = int(match.group(1))
            if mysql_errno == 1062:
                return dbexceptions.IntegrityError(new_args)
            # TODO(sougou/liguo): remove this case once servers are deployed
            if mysql_errno == 1290 and 'read-only' in msg:
                return dbexceptions.RetryError(new_args)
        return dbexceptions.DatabaseError(new_args)
    # ProgrammingError is tested before the generic GoRpcError branch.
    if isinstance(exc, gorpc.ProgrammingError):
        return dbexceptions.ProgrammingError(new_args)
    if isinstance(exc, gorpc.GoRpcError):
        return dbexceptions.FatalError(new_args)
    return exc
Exemple #6
0
    def execute(self, sql, bind_variables, **kargs):
        """Execute sql on the cursor's keyspace and return the row count.

        The result attributes are reset first. A bare 'begin', 'commit' or
        'rollback' (compared after strip() and lower()) is dispatched to the
        cursor method of the same name and returns None. kargs is accepted
        but not used.

        Raises:
          dbexceptions.DatabaseError: for DML on a non-writable cursor.
        """
        self.rowcount = 0
        self.results = None
        self.description = None
        self.lastrowid = None

        normalized = sql.strip().lower()
        if normalized in ('begin', 'commit', 'rollback'):
            # The statement text doubles as the name of the cursor method.
            getattr(self, normalized)()
            return None

        # NOTE: This check may also be done at high-layers but adding it here for completion.
        if write_sql_pattern.match(sql) and not self.is_writable():
            raise dbexceptions.DatabaseError(
                'DML on a non-writable cursor', sql)

        outcome = self._conn._execute(
            sql,
            bind_variables,
            self.keyspace,
            self.tablet_type,
            keyspace_ids=self.keyspace_ids,
            keyranges=self.keyranges,
            not_in_transaction=(not self.is_writable()))
        self.results, self.rowcount, self.lastrowid, self.description = outcome
        self.index = 0
        return self.rowcount
    def convert_to_dbexception(self, args):
        """Converts from a VitessError to the appropriate dbexceptions class.

        Args:
          args: argument tuple to use to create the new exception.

        Returns:
          An exception from dbexceptions (returned, not raised).
        """
        # FIXME(alainjobart): this is extremely confusing: self.message is only
        # used for integrity errors, and nothing else. The other cases
        # have to provide the message in the args.
        # NOTE(review): self.message is also read below for throttling
        # detection in the UNAVAILABLE branch.
        if self.code == vtrpc_pb2.UNAVAILABLE:
            if throttler_err_re.search(self.message):
                return dbexceptions.ThrottledError(args)
            return dbexceptions.TransientError(args)
        if self.code == vtrpc_pb2.FAILED_PRECONDITION:
            return dbexceptions.QueryNotServed(args)
        if self.code == vtrpc_pb2.ALREADY_EXISTS:
            # Prune the error message to truncate after the mysql errno, since
            # the error message may contain the query string with bind variables.
            msg = self.message.lower()
            match = self._errno_pattern.search(msg)
            if not match:
                # No errno to cut at. Splitting on the pattern here used to
                # raise IndexError; keep the caller's args instead of exposing
                # an unpruned message.
                return dbexceptions.IntegrityError(args)
            # Slicing at the end of the match avoids find() on the trailing
            # text, which misfires when that text is empty or repeated earlier.
            new_args = (msg[:match.end()],) + tuple(args[1:])
            return dbexceptions.IntegrityError(new_args)
        if self.code == vtrpc_pb2.INVALID_ARGUMENT:
            return dbexceptions.ProgrammingError(args)
        return dbexceptions.DatabaseError(args)
Exemple #8
0
    def execute_entity_ids(self,
                           sql,
                           bind_variables,
                           entity_keyspace_id_map,
                           entity_column_name,
                           effective_caller_id=None):
        """Run a read-only query routed by an entity-id to keyspace-id map.

        Args:
          sql: statement text; write statements are rejected.
          bind_variables: bind variables forwarded to the connection.
          entity_keyspace_id_map: map forwarded to _execute_entity_ids.
          entity_column_name: routing column forwarded to _execute_entity_ids.
          effective_caller_id: when not None, installed on the cursor via
            set_effective_caller_id before the query runs.

        Returns:
          self.rowcount after the call.

        Raises:
          dbexceptions.DatabaseError: if sql is a write query.
        """
        # FIXME: Remove effective_caller_id from interface.
        self._clear_list_state()

        # This is by definition a scatter query, so raise exception.
        if write_sql_pattern.match(sql):
            raise dbexceptions.DatabaseError(
                'execute_entity_ids is not allowed for write queries')
        if effective_caller_id is not None:
            self.set_effective_caller_id(effective_caller_id)

        conn = self._get_conn()
        outcome = conn._execute_entity_ids(
            sql,
            bind_variables,
            self.keyspace,
            self.tablet_type,
            entity_keyspace_id_map,
            entity_column_name,
            not_in_transaction=not self.is_writable(),
            effective_caller_id=self.effective_caller_id)
        self.results, self.rowcount, self.lastrowid, self.description = outcome
        return self.rowcount
Exemple #9
0
    def execute(self, sql, bind_variables, **kargs):
        """Execute sql on the cursor's keyspace and return the row count.

        A truthy 'effective_caller_id' in kargs is installed on the cursor
        first. Statements accepted by _handle_transaction_sql return None.

        Raises:
          dbexceptions.DatabaseError: for DML on a non-writable cursor.
        """
        self._clear_list_state()
        # FIXME: Remove effective_caller_id from interface.
        caller_id = kargs.get('effective_caller_id')
        if caller_id:
            self.set_effective_caller_id(caller_id)
        if self._handle_transaction_sql(sql):
            return None

        # NOTE: This check may also be done at higher layers but adding it
        # here for completion.
        if write_sql_pattern.match(sql) and not self.is_writable():
            raise dbexceptions.DatabaseError(
                'DML on a non-writable cursor', sql)

        conn = self._get_conn()
        outcome = conn._execute(
            sql,
            bind_variables,
            self.keyspace,
            self.tablet_type,
            keyspace_ids=self.keyspace_ids,
            keyranges=self.keyranges,
            not_in_transaction=not self.is_writable(),
            effective_caller_id=self.effective_caller_id)
        self.results, self.rowcount, self.lastrowid, self.description = outcome
        return self.rowcount
Exemple #10
0
  def convert_to_dbexception(self, args):
    """Converts from a TabletError to the appropriate dbexceptions class.

    Args:
      args: argument tuple to use to create the new exception.

    Returns:
      An exception from dbexceptions (returned, not raised).
    """
    if self.code == vtrpc_pb2.QUERY_NOT_SERVED:
      return dbexceptions.RetryError(args)

    if self.code == vtrpc_pb2.INTERNAL_ERROR:
      return dbexceptions.FatalError(args)

    if self.code == vtrpc_pb2.RESOURCE_EXHAUSTED:
      return dbexceptions.TxPoolFull(args)

    if self.code == vtrpc_pb2.INTEGRITY_ERROR:
      # Prune the error message to truncate after the mysql errno, since
      # the error message may contain the query string with bind variables.
      msg = self.message.lower()
      match = self._errno_pattern.search(msg)
      if not match:
        # No errno to cut at. Splitting on the pattern here used to raise
        # IndexError; keep the caller's args instead of an unpruned message.
        return dbexceptions.IntegrityError(args)
      # Slicing at the end of the match avoids find() on the trailing text,
      # which misfires when that text is empty or repeated earlier.
      new_args = (msg[:match.end()],) + tuple(args[1:])
      return dbexceptions.IntegrityError(new_args)

    return dbexceptions.DatabaseError(args)
Exemple #11
0
    def delete_by_columns(class_,
                          cursor,
                          where_column_value_pairs,
                          limit=None):
        """Delete matching rows, then the sharding-key entity-id lookup rows.

        Args:
          cursor: cursor whose routing.sharding_key scopes the delete; it
            also executes the DELETE.
          where_column_value_pairs: non-empty (column, value) pairs for the
            WHERE clause.
          limit: optional row limit passed to the query builder.

        Returns:
          cursor.rowcount of the DELETE.

        Raises:
          dbexceptions.ProgrammingError: missing sharding key or empty pairs.
          dbexceptions.DatabaseError: when no row was deleted.
        """
        sharding_key = cursor.routing.sharding_key
        if sharding_key is None:
            raise dbexceptions.ProgrammingError("sharding_key cannot be empty")
        if not where_column_value_pairs:
            raise dbexceptions.ProgrammingError(
                "deleting the whole table is not allowed")

        statement, bind_vars = sql_builder.delete_by_columns_query(
            class_.table_name, where_column_value_pairs, limit=limit)
        cursor.execute(statement, bind_vars)
        deleted = cursor.rowcount
        if deleted == 0:
            raise dbexceptions.DatabaseError("DB Row not found")

        # Remove the lookup-map rows through a cursor derived from this one.
        make_lookup_cursor = functools.partial(
            db_object.create_cursor_from_old_cursor, cursor)
        class_.delete_sharding_key_entity_id_lookup(make_lookup_cursor,
                                                    sharding_key)
        return deleted
Exemple #12
0
  def execute_entity_ids(
      self, sql, bind_variables, entity_keyspace_id_map, entity_column_name,
      effective_caller_id=None):
    """Run a read-only query routed by an entity-id to keyspace-id map.

    The result attributes are reset first; effective_caller_id is passed
    straight to the connection.

    Returns:
      self.rowcount after the call.

    Raises:
      dbexceptions.DatabaseError: if sql is a write query.
    """
    self.rowcount = 0
    self.results = self.description = self.lastrowid = None

    # This is by definition a scatter query, so raise exception.
    if write_sql_pattern.match(sql):
      raise dbexceptions.DatabaseError(
          'execute_entity_ids is not allowed for write queries')

    outcome = self._conn._execute_entity_ids(
        sql,
        bind_variables,
        self.keyspace,
        self.tablet_type,
        entity_keyspace_id_map,
        entity_column_name,
        not_in_transaction=not self.is_writable(),
        effective_caller_id=effective_caller_id)
    self.results, self.rowcount, self.lastrowid, self.description = outcome
    self.index = 0
    return self.rowcount
Exemple #13
0
  def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
    """Delete the rows matching where_column_value_pairs.

    Args:
      cursor: cursor that executes the DELETE built by
        class_.create_delete_query.
      where_column_value_pairs: non-empty (column, value) pairs.
      limit: optional row limit passed to create_delete_query.

    Returns:
      cursor.rowcount of the DELETE.

    Raises:
      dbexceptions.ProgrammingError: when no pairs are given.
      dbexceptions.DatabaseError: when no row was deleted.
    """
    if not where_column_value_pairs:
      raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")

    statement, bind_vars = class_.create_delete_query(
        where_column_value_pairs, limit=limit)
    cursor.execute(statement, bind_vars)
    deleted = cursor.rowcount
    if deleted == 0:
      raise dbexceptions.DatabaseError("DB Row not found")
    return deleted
Exemple #14
0
def get_keyrange_from_shard_name(keyspace, shard_name, db_type):
    """Return the keyrange.KeyRange covered by shard_name within keyspace.

    When is_sharded_keyspace(keyspace, db_type) is false, only
    keyrange_constants.SHARD_ZERO is accepted and maps to
    keyrange.KeyRange(NON_PARTIAL_KEYRANGE). Otherwise shard_name is passed
    to keyrange.KeyRange unchanged.

    Raises:
      dbexceptions.DatabaseError: if shard_name is not SHARD_ZERO for an
        unsharded keyspace.
    """
    if is_sharded_keyspace(keyspace, db_type):
        return keyrange.KeyRange(shard_name)
    if shard_name == keyrange_constants.SHARD_ZERO:
        return keyrange.KeyRange(keyrange_constants.NON_PARTIAL_KEYRANGE)
    # Interpolate the values: passing them as extra exception args leaves
    # the '%s' placeholders unformatted in the message.
    raise dbexceptions.DatabaseError(
        'Invalid shard_name %s for keyspace %s' % (shard_name, keyspace))
Exemple #15
0
def get_keyrange(shard_name):
    """Return the key range for shard_name.

    keyrange_constants.SHARD_ZERO maps to
    keyrange_constants.NON_PARTIAL_KEYRANGE. Any other name must have the
    form '<start>-<end>' with hex-encoded (possibly empty) bounds; a
    keyrange.KeyRange over the decoded bounds is returned.

    Raises:
      dbexceptions.DatabaseError: if shard_name is not of the form
        '<start>-<end>'.
    """
    import binascii

    if shard_name == keyrange_constants.SHARD_ZERO:
        return keyrange_constants.NON_PARTIAL_KEYRANGE
    kr_parts = shard_name.split('-')
    if len(kr_parts) != 2:
        # Only shard_name is in scope here; naming a keyspace in the message
        # would raise NameError instead of the intended DatabaseError.
        raise dbexceptions.DatabaseError('Invalid shard_name %s' % shard_name)
    # binascii.unhexlify matches str.decode('hex') on Python 2 and also works
    # on Python 3.
    start, end = [binascii.unhexlify(part) for part in kr_parts]
    return keyrange.KeyRange((start, end))
 def stream_next(self):
     """Fetch the next streamed event as a plain dict.

     Returns:
       The __dict__ of an EventData built from the response's reply, or None
       when the client's stream_next() returns None.

     Raises:
       dbexceptions.DatabaseError: translated from gorpc.AppError.
       dbexceptions.OperationalError: translated from other gorpc.GoRpcError.
     """
     try:
         response = self.client.stream_next()
         if response is None:
             return None
         event = EventData(response.reply)
         return event.__dict__
     except gorpc.AppError as err:
         raise dbexceptions.DatabaseError(*err.args)
     except gorpc.GoRpcError as err:
         raise dbexceptions.OperationalError(*err.args)
     except:
         # Deliberately catch everything else: log the traceback, then let
         # the original exception propagate unchanged.
         logging.exception('gorpc low-level error')
         raise
Exemple #17
0
    def stream_update(self, position, timeout=3600.0):
        """Yield update_stream.StreamEvent objects starting at position.

        Note this implementation doesn't honor the timeout.

        Reply categories 'DML', 'DDL' and 'POS' map to the StreamEvent
        constant of the same name; any other category maps to ERR. Primary
        key fields and values, when present, are converted into field names
        and row tuples.

        Raises:
          dbexceptions.DatabaseError: translated from gorpc.AppError.
          dbexceptions.OperationalError: translated from other gorpc.GoRpcError.
        """
        try:
            self.client.stream_call('UpdateStream.ServeUpdateStream',
                                    {'Position': position})
            while True:
                response = self.client.stream_next()
                if response is None:
                    return
                reply = response.reply

                label = reply['Category']
                if label in ('DML', 'DDL', 'POS'):
                    # The category label doubles as the StreamEvent attribute name.
                    category = getattr(update_stream.StreamEvent, label)
                else:
                    category = update_stream.StreamEvent.ERR

                fields = []
                rows = []
                pk_fields = reply['PrimaryKeyFields']
                if pk_fields:
                    fields = [field['Name'] for field in pk_fields]
                    conversions = [field_types.conversions.get(field['Type'])
                                   for field in pk_fields]
                    # Empty primary-key value lists are skipped.
                    rows = [tuple(_make_row(pk_list, conversions))
                            for pk_list in reply['PrimaryKeyValues']
                            if pk_list]

                yield update_stream.StreamEvent(
                    category=category,
                    table_name=reply['TableName'],
                    fields=fields,
                    rows=rows,
                    sql=reply['Sql'],
                    timestamp=reply['Timestamp'],
                    transaction_id=reply['TransactionID'])
        except gorpc.AppError as err:
            raise dbexceptions.DatabaseError(*err.args)
        except gorpc.GoRpcError as err:
            raise dbexceptions.OperationalError(*err.args)
Exemple #18
0
def _convert_exception(exc, *args, **kwargs):
  """This parses the protocol exceptions to the api interface exceptions.

  Every converted exception is passed to vtgate_utils.log_exception before
  it is returned, so the logging side effect applies on all paths, including
  throttled/transient and query-not-served errors.

  Args:
    exc: raw protocol exception. Its args are extended in place with *args
      and the stringified **kwargs.
    *args: additional args from the raising site.
    **kwargs: additional keyword args from the raising site.
              They will be converted into a single string, and added as an extra
              arg to the exception. 'keyspace' and 'tablet_type' are also
              passed to the logger.

  Returns:
    Api interface exceptions - dbexceptions with new args, or exc itself when
    its type is not recognized.
  """
  kwargs_as_str = vtgate_utils.convert_exception_kwargs(kwargs)
  exc.args += args
  if kwargs_as_str:
    exc.args += (kwargs_as_str,)
  new_args = (type(exc).__name__,) + exc.args
  if isinstance(exc, vtgate_utils.VitessError):
    new_exc = exc.convert_to_dbexception(new_args)
  elif isinstance(exc, face.ExpirationError):
    # face.ExpirationError is returned by the gRPC library when
    # a request times out. Note it is a subclass of face.AbortionError
    # so we have to test for it before.
    new_exc = dbexceptions.TimeoutError(new_args)
  elif isinstance(exc, face.AbortionError):
    # face.AbortionError is the toplevel error returned by gRPC for any
    # RPC that finishes earlier than expected.
    msg = exc.details
    if exc.code == interfaces.StatusCode.UNAVAILABLE:
      # Guard against missing details so the regex search cannot raise.
      if _throttler_err_pattern.search(msg or ''):
        new_exc = dbexceptions.ThrottledError(new_args)
      else:
        new_exc = dbexceptions.TransientError(new_args)
    elif exc.code == interfaces.StatusCode.ALREADY_EXISTS:
      new_exc = _prune_integrity_error(msg, new_args)
    elif exc.code == interfaces.StatusCode.FAILED_PRECONDITION:
      new_exc = dbexceptions.QueryNotServed(msg, new_args)
    else:
      # Unhandled RPC application error
      new_exc = dbexceptions.DatabaseError(new_args + (msg,))
  else:
    new_exc = exc
  vtgate_utils.log_exception(
      new_exc,
      keyspace=kwargs.get('keyspace'), tablet_type=kwargs.get('tablet_type'))
  return new_exc
Exemple #19
0
def handle_app_error(exc_args):
    """Map application-error args to a dbexceptions instance.

    Args:
      exc_args: tuple of exception args; exc_args[0] is the error message,
        which is lower-cased before inspection.

    Returns:
      dbexceptions.RequestBacklog for 'request_backlog' messages,
      dbexceptions.IntegrityError (with a pruned message) for MySQL errno
      1062, otherwise dbexceptions.DatabaseError. The exception is returned,
      not raised.
    """
    msg = str(exc_args[0]).lower()
    if msg.startswith('request_backlog'):
        return dbexceptions.RequestBacklog(exc_args)
    match = _errno_pattern.search(msg)
    if match and int(match.group(1)) == 1062:
        # Prune the error message to truncate the query string
        # returned by mysql as it contains bind variables. Cutting at
        # match.end() is exact. Using find() on the trailing text yields an
        # empty prefix when that text is empty.
        new_args = (msg[:match.end()],) + tuple(exc_args[1:])
        return dbexceptions.IntegrityError(new_args)
    return dbexceptions.DatabaseError(exc_args)
Exemple #20
0
    def execute(self, sql, bind_variables, **kargs):
        """Execute sql on the cursor's keyspace and return the row count.

        The result attributes are reset first. A bare 'begin', 'commit' or
        'rollback' (compared after strip() and lower()) is dispatched to the
        cursor method of the same name and returns None. For write queries,
        the cursor must be writable and correctly routed: exactly one keyspace
        id on a sharded keyspace, or the non-partial keyrange on an unsharded
        one. kargs is accepted but not used.

        Raises:
          dbexceptions.DatabaseError: DML on a non-writable cursor.
          dbexceptions.ProgrammingError: DML with incorrect routing.
        """
        self.rowcount = 0
        self.results = None
        self.description = None
        self.lastrowid = None

        normalized = sql.strip().lower()
        if normalized in ('begin', 'commit', 'rollback'):
            # The statement text doubles as the name of the cursor method.
            getattr(self, normalized)()
            return None

        # NOTE: This check may also be done at high-layers but adding it here for completion.
        if write_sql_pattern.match(sql):
            if not self.is_writable():
                raise dbexceptions.DatabaseError(
                    'DML on a non-writable cursor', sql)

            # FIXME(shrutip): these checks should be on vtgate server, so
            # dependency on topology can be removed.
            if topology.is_sharded_keyspace(self.keyspace, self.tablet_type):
                ids = self.keyspace_ids
                if ids is None or len(ids) != 1:
                    raise dbexceptions.ProgrammingError(
                        'DML on zero or multiple keyspace ids is not allowed: %r'
                        % ids)
            else:
                ranges = self.keyranges
                if not ranges or str(ranges[0]) != keyrange_constants.NON_PARTIAL_KEYRANGE:
                    raise dbexceptions.ProgrammingError(
                        'Keyrange not correct for non-sharded keyspace: %r' %
                        ranges)

        outcome = self._conn._execute(
            sql,
            bind_variables,
            self.keyspace,
            self.tablet_type,
            keyspace_ids=self.keyspace_ids,
            keyranges=self.keyranges)
        self.results, self.rowcount, self.lastrowid, self.description = outcome
        self.index = 0
        return self.rowcount
Exemple #21
0
  def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None,
                        **columns):
    """Delete the rows matching the given column/value pairs.

    Args:
      cursor: cursor that executes the DELETE; its rowcount is read afterwards.
      where_column_value_pairs: (column, value) pairs for the WHERE clause.
        When empty, the keyword arguments in **columns are used instead,
        sorted by column name.
      limit: optional row limit passed to sql_builder.delete_by_columns_query.
      **columns: fallback column=value filters.

    Returns:
      cursor.rowcount of the DELETE.

    Raises:
      dbexceptions.ProgrammingError: when no filters are given at all.
      dbexceptions.DatabaseError: when no row was deleted.
    """
    if not where_column_value_pairs:
      # sorted() works on Python 2 item lists and Python 3 dict views alike;
      # a dict_items view has no .sort() method.
      where_column_value_pairs = sorted(columns.items())

    if not where_column_value_pairs:
      raise dbexceptions.ProgrammingError("deleting the whole table is not allowed")

    query, bind_variables = sql_builder.delete_by_columns_query(
        class_.table_name, where_column_value_pairs, limit=limit)
    cursor.execute(query, bind_variables)
    if cursor.rowcount == 0:
      raise dbexceptions.DatabaseError("DB Row not found")
    return cursor.rowcount
  def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None):
    """Delete matching rows, scoped to the cursor's first keyspace id.

    Args:
      cursor: cursor that executes the DELETE; cursor.keyspace_ids[0] is
        unpacked and added to the filter via class_._add_keyspace_id.
      where_column_value_pairs: non-empty (column, value) pairs.
      limit: optional row limit passed to the query builder.

    Returns:
      cursor.rowcount of the DELETE.

    Raises:
      dbexceptions.ProgrammingError: when no pairs are given.
      dbexceptions.DatabaseError: when no row was deleted.
    """
    if not where_column_value_pairs:
      raise dbexceptions.ProgrammingError(
          'deleting the whole table is not allowed')

    keyspace_id = unpack_keyspace_id(cursor.keyspace_ids[0])
    scoped_pairs = class_._add_keyspace_id(keyspace_id,
                                           where_column_value_pairs)

    statement, bind_vars = sql_builder.delete_by_columns_query(
        class_.table_name, scoped_pairs, limit=limit)
    cursor.execute(statement, bind_vars)
    deleted = cursor.rowcount
    if deleted == 0:
      raise dbexceptions.DatabaseError('DB Row not found')
    return deleted
def convert_exception(exc, *args):
    """Translate a gorpc client exception into a db-api style exception.

    Args:
      exc: exception raised by the gorpc client.
      *args: extra values appended to exc.args in the converted exception.

    Returns:
      The converted exception (returned, not raised), or exc itself when its
      type is not one of the handled gorpc errors. For AppErrors carrying a
      MySQL errno, the class is looked up in _errno_map, defaulting to
      dbexceptions.DatabaseError.
    """
    new_args = exc.args + args
    if isinstance(exc, gorpc.TimeoutError):
        # NOTE(review): unqualified TimeoutError/FatalError resolve to
        # module-level names not visible here; other variants use
        # dbexceptions.TimeoutError/FatalError -- confirm which is intended.
        return TimeoutError(new_args)
    if isinstance(exc, gorpc.AppError):
        # Read the message from exc.args. Indexing the exception itself
        # (exc[0]) only works on Python 2, and an AppError without args
        # should still convert rather than raise IndexError here.
        msg = str(exc.args[0]).lower() if exc.args else ''
        match = _errno_pattern.search(msg)
        if match:
            mysql_errno = int(match.group(1))
            return _errno_map.get(mysql_errno,
                                  dbexceptions.DatabaseError)(new_args)
        return dbexceptions.DatabaseError(new_args)
    if isinstance(exc, gorpc.ProgrammingError):
        return dbexceptions.ProgrammingError(new_args)
    if isinstance(exc, gorpc.GoRpcError):
        return FatalError(new_args)
    return exc
Exemple #24
0
def convert_exception(exc, *args):
    """Log a gorpc client exception and translate it into a dbexceptions error.

    The raw exception is first reported through
    vtdb_logger.get_logger().vtgatev2_exception.

    Args:
      exc: exception raised by the gorpc client.
      *args: extra values appended to exc.args in the converted exception.

    Returns:
      A dbexceptions instance (returned, not raised), or exc itself when its
      type is not one of the handled gorpc errors.
    """
    new_args = exc.args + args
    vtdb_logger.get_logger().vtgatev2_exception(exc)
    if isinstance(exc, gorpc.TimeoutError):
        return dbexceptions.TimeoutError(new_args)
    if isinstance(exc, gorpc.AppError):
        # Read the message from exc.args. Indexing the exception itself
        # (exc[0]) only works on Python 2, and an AppError without args
        # should still convert rather than raise IndexError here.
        msg = str(exc.args[0]).lower() if exc.args else ''
        match = _errno_pattern.search(msg)
        if match and int(match.group(1)) == 1062:
            return dbexceptions.IntegrityError(new_args)
        return dbexceptions.DatabaseError(new_args)
    if isinstance(exc, gorpc.ProgrammingError):
        return dbexceptions.ProgrammingError(new_args)
    if isinstance(exc, gorpc.GoRpcError):
        return dbexceptions.FatalError(new_args)
    return exc
Exemple #25
0
def convert_exception(exc, *args):
    """Translate a gorpc client exception into a dbexceptions error.

    Args:
      exc: exception raised by the gorpc client.
      *args: extra values appended to exc.args in the converted exception.

    Returns:
      A dbexceptions instance (returned, not raised), or exc itself when its
      type is not one of the handled gorpc errors.
    """
    new_args = exc.args + args
    if isinstance(exc, gorpc.TimeoutError):
        return dbexceptions.TimeoutError(new_args)
    if isinstance(exc, gorpc.AppError):
        # Read the message from exc.args. Indexing the exception itself
        # (exc[0]) only works on Python 2, and an AppError without args
        # should still convert rather than raise IndexError here.
        msg = str(exc.args[0]).lower() if exc.args else ''
        if msg.startswith('request_backlog'):
            return dbexceptions.RequestBacklog(new_args)
        match = _errno_pattern.search(msg)
        if match and int(match.group(1)) == 1062:
            return dbexceptions.IntegrityError(new_args)
        return dbexceptions.DatabaseError(new_args)
    if isinstance(exc, gorpc.ProgrammingError):
        return dbexceptions.ProgrammingError(new_args)
    if isinstance(exc, gorpc.GoRpcError):
        return dbexceptions.FatalError(new_args)
    return exc
  def lookup_sharding_key_from_entity_id(
      class_, cursor_method, entity_id_column, entity_id):
    """Map entity ids to sharding keys through the entity's lookup table.

    Args:
      cursor_method: Cursor method, passed through to the lookup class's get().
      entity_id_column: Non-sharding key index that can be used for query
        routing; must be a key of class_.entity_id_lookup_map.
      entity_id: entity id value, passed through to the lookup class's get().

    Returns:
      dict mapping each entity id in the returned lookup rows to its sharding
      key, to be used for routing.

    Raises:
      dbexceptions.DatabaseError: if the lookup returns no rows.
      dbexceptions.ProgrammingError: if the sharding key column has to be
        inferred and the lookup rows do not have exactly two columns.
    """
    entity_lookup_column = class_.get_lookup_column_name(entity_id_column)
    lookup_class = class_.entity_id_lookup_map[entity_id_column]
    rows = lookup_class.get(cursor_method, entity_lookup_column, entity_id)
    if not rows:
      raise dbexceptions.DatabaseError('LookupRow not found')

    if class_.sharding_key_column_name is not None:
      sk_lookup_column = class_.get_lookup_column_name(
          class_.sharding_key_column_name)
    else:
      # This is needed since the table may not have a sharding key column name
      # but the lookup map will have it: it is whichever of the two lookup
      # columns is not the entity column.
      lookup_column_names = list(rows[0].keys())
      if len(lookup_column_names) != 2:
        raise dbexceptions.ProgrammingError(
            'lookup table must have exactly two columns.')
      # Compare whole column names. set(list(name)) would split the entity
      # column into characters and leave both columns as candidates.
      sk_lookup_column = [name for name in lookup_column_names
                          if name != entity_lookup_column][0]

    return {row[entity_lookup_column]: row[sk_lookup_column] for row in rows}
Exemple #27
0
def convert_exception(exc, *args):
    new_args = exc.args + args
    if isinstance(exc, gorpc.TimeoutError):
        return dbexceptions.TimeoutError(new_args)
    elif isinstance(exc, gorpc.AppError):
        msg = str(exc[0]).lower()
        if msg.startswith('retry'):
            return dbexceptions.RetryError(new_args)
        if msg.startswith('fatal'):
            return dbexceptions.FatalError(new_args)
        if msg.startswith('tx_pool_full'):
            return dbexceptions.TxPoolFull(new_args)
        match = _errno_pattern.search(msg)
        if match:
            mysql_errno = int(match.group(1))
            return _errno_map.get(mysql_errno,
                                  dbexceptions.DatabaseError)(new_args)
        return dbexceptions.DatabaseError(new_args)
    elif isinstance(exc, gorpc.ProgrammingError):
        return dbexceptions.ProgrammingError(new_args)
    elif isinstance(exc, gorpc.GoRpcError):
        return dbexceptions.FatalError(new_args)
    return exc
Exemple #28
0
    def convert_to_dbexception(self, args):
        """Converts from a VitessError to the appropriate dbexceptions class.

        Args:
          args: argument tuple to use to create the new exception.

        Returns:
          An exception from dbexceptions (returned, not raised).
        """
        # FIXME(alainjobart): this is extremely confusing: self.message is only
        # used for integrity errors, and nothing else. The other cases
        # have to provide the message in the args.
        if self.code == vtrpc_pb2.TRANSIENT_ERROR:
            return dbexceptions.TransientError(args)
        if self.code == vtrpc_pb2.INTEGRITY_ERROR:
            # Prune the error message to truncate after the mysql errno, since
            # the error message may contain the query string with bind variables.
            msg = self.message.lower()
            match = self._errno_pattern.search(msg)
            if not match:
                # No errno to cut at. Splitting on the pattern here used to
                # raise IndexError; keep the caller's args instead of exposing
                # an unpruned message.
                return dbexceptions.IntegrityError(args)
            # Slicing at the end of the match avoids find() on the trailing
            # text, which misfires when that text is empty or repeated earlier.
            new_args = (msg[:match.end()],) + tuple(args[1:])
            return dbexceptions.IntegrityError(new_args)

        return dbexceptions.DatabaseError(args)