Beispiel #1
0
    def find_and_add_copy_from_text(substr):
        """Copies `substr` out of the utterance wordpieces into the gold query.

        Every contiguous span of `example.model_input.utterance_wordpieces`
        (`example` comes from the enclosing scope) is tried, earliest start
        first and shortest span first. The span's pieces are joined with
        spaces, then ' ##' is removed, and the result is compared with the
        lower-cased `substr`. For the first matching span, one
        `SQLAction(utterance_copy=...)` per wordpiece is appended to
        `example.gold_sql_query.actions`.

        Args:
            substr: Text to locate; it is lower-cased before comparison.

        Returns:
            True if a matching span was found and its copy actions were
            appended, False otherwise (nothing is appended in that case).
        """
        wordpieces = example.model_input.utterance_wordpieces
        num_pieces = len(wordpieces)
        # Loop-invariant: lower-case the target once instead of once per span.
        target = substr.lower()

        for start in range(num_pieces):
            for end in range(start + 1, num_pieces + 1):
                span = wordpieces[start:end]
                # A piece carrying '##' continues the previous piece, so the
                # joining space in front of it is dropped.
                composed = ' '.join(
                    piece.wordpiece for piece in span).replace(' ##', '')
                if composed == target:
                    # Found wordpiece(s, when put together) comprising this
                    # item; returning here replaces the flag/double-break.
                    for piece in span:
                        example.gold_sql_query.actions.append(
                            SQLAction(utterance_copy=piece))
                    return True
        return False
Beispiel #2
0
def _parse_identifier(sql, example, anonymize_values):
  """Turns the tokens of an SQL identifier into gold query actions.

  Iterates over `sql` in order and appends the matching action(s) to
  `example.gold_sql_query.actions` for each token, recursing into nested
  identifiers.

  Args:
    sql: Iterable of parsed SQL tokens that make up the identifier.
    example: Example whose gold query is being built; mutated in place.
    anonymize_values: Forwarded to `_add_simple_value` for value tokens.

  Raises:
    ValueError: If a name maps to an entity that is neither a `DatabaseTable`
      nor a `TableColumn`, or if adding a literal leaves the number of gold
      actions unchanged.
    ParseError: If a token falls into none of the handled categories.
  """
  for token in sql:
    if token.ttype == sqlparse.tokens.Text.Whitespace:
      continue

    if _is_identifier(token):
      # Nested identifier: apply the same rules to its own tokens.
      _parse_identifier(token, example, anonymize_values)
      continue

    # Tokens accepted by `_is_order`, table aliases, the '.' separator and
    # the 'as' keyword are all handed to `_add_simple_step` unchanged.
    if (_is_order(token) or _is_table_alias(token) or
        (_is_punctuation(token) and token.value == '.') or
        (_is_keyword(token) and token.value in ('as',))):
      _add_simple_step(token, example)
      continue

    if _is_name(token):
      entity = _find_simple_entity(token.value, example)

      if entity is None:
        try:
          _resolve_reference(token, example)
        except AttributeError as e:
          # Generally this means the reference wasn't found i.e., in WikiSQL, a
          # value didn't have quotes, so just add it as a value
          print(e)
          _add_simple_value(token, example, anonymize_values)
        continue

      # The name maps directly onto a schema entity: copy it.
      if isinstance(entity, DatabaseTable):
        schema_copy = SchemaEntityCopy(copied_table=entity)
      elif isinstance(entity, TableColumn):
        schema_copy = SchemaEntityCopy(copied_column=entity)
      else:
        raise ValueError('Type of entity is unexpected: ' + str(type(entity)))

      example.gold_sql_query.actions.append(SQLAction(entity_copy=schema_copy))
      continue

    if _is_literal(token):
      # Adding a literal must produce at least one new action.
      actions_before = len(example.gold_sql_query.actions)
      _add_simple_value(token, example, anonymize_values)
      if len(example.gold_sql_query.actions) == actions_before:
        raise ValueError(
            'Gold query did not change length when adding simple value!')
      continue

    # Nothing matched: dump state for debugging, then fail.
    _debug_state(token, example)
    raise ParseError('Incomplete _parse_identifier')
Beispiel #3
0
  def _find_direction(reverse):
    """Copies the column named by `item` from the first table that has it.

    `item` and `example` are taken from the enclosing scope.

    Args:
      reverse: Passed through to `_find_from_table` as its `reverse` argument.

    Returns:
      True once a column copy action has been appended to the gold query;
      False if `_find_from_table` returned nothing.

    Raises:
      ParseError: If tables were returned but none yielded a column entity.
    """
    candidate_tables = _find_from_table(item, example, reverse=reverse)

    if candidate_tables:
      for table in candidate_tables:
        column = _find_column_entities(item.value, example, table)
        if not column:
          continue

        example.gold_sql_query.actions.append(
            SQLAction(entity_copy=SchemaEntityCopy(copied_column=column)))
        return True

      # Tables exist in this direction, but none of them has the column.
      raise ParseError('Unable to find annotation of table ' + str(item))

    return False
Beispiel #4
0
def _add_value_literal(item_str, example, anonymize):
    """Appends the gold actions that produce a literal SQL value.

    When `anonymize` is set, the three symbols '"', 'value', '"' are emitted.
    Otherwise the value is quoted if needed and, where its text occurs in the
    utterance, copied from the utterance wordpieces; failing that it is
    emitted as a single generated symbol.

    Args:
        item_str: Text of the value as it appears in the SQL query.
        example: Example being built; `example.gold_sql_query.actions` is
            extended in place.
        anonymize: If true, emit the quoted placeholder instead of the value.

    Returns:
        True if the value was anonymized or copied from the utterance;
        otherwise whether `item_str` (after optional quoting) is in
        `VALID_GENERATED_TOKENS`.

    Raises:
        ValueError: If the value reaches the generate-from-vocabulary fallback
            and contains 'u s a'.
    """

    def emit_symbol(symbol):
        """Appends a generated-symbol action to the gold query."""
        example.gold_sql_query.actions.append(SQLAction(symbol=symbol))

    def find_and_add_copy_from_text(substr):
        """Finds `substr` among the utterance wordpieces and copies that span.

        Returns True if a contiguous run of wordpieces re-joins to the
        lower-cased `substr` (copy actions are appended), else False.
        """
        wordpieces = example.model_input.utterance_wordpieces
        target = substr.lower()
        for start in range(len(wordpieces)):
            for end in range(start + 1, len(wordpieces) + 1):
                span = wordpieces[start:end]
                # A '##' piece continues the previous one, so it should not
                # have a space in front.
                composed = ' '.join(
                    piece.wordpiece for piece in span).replace(' ##', '')
                if composed == target:
                    for piece in span:
                        example.gold_sql_query.actions.append(
                            SQLAction(utterance_copy=piece))
                    return True
        return False

    if anonymize:
        for placeholder_symbol in ('"', 'value', '"'):
            emit_symbol(placeholder_symbol)
        return True

    # Add quotes if [1] there aren't quotes and [2] it's not numeric
    is_numeric = item_str.replace('.', '', 1).isdigit()
    if (not is_numeric and item_str.count('"') < 2 and
            item_str.count('\'') < 2):
        item_str = "'" + item_str + "'"

    utterance = example.model_input.original_utterance.lower()

    # Case 1: the value, exactly as written, occurs in the utterance.
    if item_str.lower() in utterance:
        copied = find_and_add_copy_from_text(item_str)

        # A successful copy has already appended its actions; the symbol is
        # added when copying failed or the token is a valid generated token.
        if not copied or item_str in VALID_GENERATED_TOKENS:
            emit_symbol(item_str)

        return copied or item_str in VALID_GENERATED_TOKENS

    quote_type = ''
    if item_str.startswith('\'') and item_str.endswith('\''):
        quote_type = '\''
    elif item_str.startswith('"') and item_str.endswith('"'):
        quote_type = '"'

    if quote_type:
        unquoted = item_str[1:-1]

        # Case 2: the text between the quotes occurs in the utterance.
        if unquoted.lower() in utterance:
            emit_symbol(quote_type)

            copied = find_and_add_copy_from_text(unquoted)
            # NOTE(review): this fallback emits the full quoted `item_str`
            # between the quote symbols, while the '%' branch below emits only
            # the inner text -- confirm the asymmetry is intended.
            if not copied or item_str in VALID_GENERATED_TOKENS:
                emit_symbol(item_str)

            emit_symbol(quote_type)

            return copied or item_str in VALID_GENERATED_TOKENS

        # Case 3: a '%...%' pattern whose inner text occurs in the utterance.
        if (item_str[1] == '%' and item_str[-2] == '%' and
                item_str[2:-2].lower() in utterance):
            pattern_body = item_str[2:-2]
            emit_symbol(quote_type)
            emit_symbol('%')

            copied = find_and_add_copy_from_text(pattern_body)
            if not copied or item_str in VALID_GENERATED_TOKENS:
                emit_symbol(pattern_body)

            emit_symbol('%')
            emit_symbol(quote_type)

            return copied or item_str in VALID_GENERATED_TOKENS

    # Just add it as choice from the output vocabulary
    # NOTE(review): values containing 'u s a' are rejected outright; the
    # reason is not evident from this function.
    if u'u s a' in item_str:
        raise ValueError('WHAT????????')
    emit_symbol(item_str)

    return item_str in VALID_GENERATED_TOKENS
Beispiel #5
0
def _add_generate_action(token, example):
    """Appends an action that generates `token` as an output symbol.

    Args:
        token: Symbol stored in the new `SQLAction`.
        example: Example whose `gold_sql_query.actions` list is extended.
    """
    gold_actions = example.gold_sql_query.actions
    gold_actions.append(SQLAction(symbol=token))
Beispiel #6
0
def _add_column_copy(table_name, column_name, example):
    """Appends an action that copies a schema column into the gold query.

    Args:
        table_name: Table name passed to `_find_column`.
        column_name: Column name passed to `_find_column`.
        example: Example used for the lookup and whose
            `gold_sql_query.actions` list is extended.
    """
    column_entity = _find_column(table_name, column_name, example)
    gold_actions = example.gold_sql_query.actions
    gold_actions.append(
        SQLAction(entity_copy=SchemaEntityCopy(copied_column=column_entity)))
Beispiel #7
0
def _add_table_copy(table_name, example):
    """Appends an action that copies a schema table into the gold query.

    Args:
        table_name: Table name passed to `_find_table`.
        example: Example used for the lookup and whose
            `gold_sql_query.actions` list is extended.
    """
    table_entity = _find_table(table_name, example)
    gold_actions = example.gold_sql_query.actions
    gold_actions.append(
        SQLAction(entity_copy=SchemaEntityCopy(copied_table=table_entity)))
Beispiel #8
0
def _resolve_reference(item, example):
  """Resolves an ambiguous token that matches multiple annotations.

  Resolution is attempted in this order, and the first success appends a
  schema copy action to `example.gold_sql_query.actions`:
    1. If the most recent gold action's symbol is 'join' or 'from', the token
       is resolved as a table via `_find_table_annotation`.
    2. If the token's parent starts with a table alias followed by '.', the
       column is looked up in the table that alias maps to.
    3. Otherwise columns are searched in the tables returned by
       `_find_from_table`, trying `reverse=True` first when the token sits in
       a where/by context, then `reverse=False`, then `reverse=True`.

  Args:
    item: position in the SQL parse where the search will start.
    example: the QuikExample containing table and column annotations.

  Raises:
    ParseError: if the ambiguity cannot be resolved.
    AssertionError: if the table lookup in step 1 or the column lookup in
      step 2 comes back empty.
  """
  last_symbol = example.gold_sql_query.actions[-1].symbol

  if last_symbol in ('join', 'from'):
    table = _find_table_annotation(item, example)
    assert table, 'Cannot find a table annotation for item %s' % item.value

    example.gold_sql_query.actions.append(
        SQLAction(entity_copy=SchemaEntityCopy(copied_table=table)))
    return

  enclosing = item.parent

  # We try the simple case, that is aliases
  enclosing_tokens = _get_tokens(enclosing)
  if (_is_table_alias(enclosing_tokens[0]) and
      enclosing_tokens[1].value == '.'):
    alias_to_table = _get_all_aliases(item, example)

    aliased_table = alias_to_table[enclosing.tokens[0].value.lower()]
    column = _find_column_entities(item.value, example, aliased_table)

    assert column, 'Could not find column with name ' + str(item.value)
    example.gold_sql_query.actions.append(
        SQLAction(entity_copy=SchemaEntityCopy(copied_column=column)))
    return

  def _find_direction(reverse):
    """Copies the column from the first table, in one direction, that has it.

    Returns True after appending the copy action, False when
    `_find_from_table` returns nothing, and raises ParseError when tables
    were returned but none yields the column.
    """
    candidate_tables = _find_from_table(item, example, reverse=reverse)

    if candidate_tables:
      for table_entity in candidate_tables:
        found = _find_column_entities(item.value, example, table_entity)
        if not found:
          continue

        example.gold_sql_query.actions.append(
            SQLAction(entity_copy=SchemaEntityCopy(copied_column=found)))
        return True

      raise ParseError('Unable to find annotation of table ' + str(item))

    return False

  # Default search order; a where/by context gets an extra reverse pass first.
  search_order = [False, True]
  if (last_symbol in ('where', 'by') or _is_where(enclosing) or
      _is_where(enclosing.parent)):
    search_order.insert(0, True)

  for reverse in search_order:
    if _find_direction(reverse=reverse):
      return

  raise ParseError('Unable to find final annotation in any direction')
Beispiel #9
0
def _add_simple_step(item, example):
  """Appends a symbol action holding the lower-cased value of `item`.

  Args:
    item: Token whose `value` is lower-cased and used as the symbol.
    example: Example whose `gold_sql_query.actions` list is extended.
  """
  gold_actions = example.gold_sql_query.actions
  gold_actions.append(SQLAction(symbol=item.value.lower()))
Beispiel #10
0
def _add_simple_value(item, example: NLToSQLExample, anonymize: bool) -> bool:
    """
    Appends the gold actions that produce a value from the SQL query.

    When `anonymize` is set, a single 'value' symbol is emitted. Otherwise the
    value is quoted if needed and, where its text occurs in the utterance,
    copied from the utterance wordpieces; failing that it is emitted as a
    single generated symbol.

    Args:
        item: The value present in the SQL query; converted with `str()`.
        example: The NLToSQLExample being constructed; its
            `gold_sql_query.actions` list is extended in place.
        anonymize: Whether to anonymize values (i.e., replace them with a
            'value' placeholder).

    Returns:
        True if the value was anonymized or copied from the utterance;
        otherwise the result of `_is_valid_generated_token` for the value
        (after optional quoting).

    Raises:
        ValueError: If the value reaches the generate-from-vocabulary fallback
            and contains "u s a".
    """

    def emit_symbol(symbol: str) -> None:
        """Appends a generated-symbol action to the gold query."""
        example.gold_sql_query.actions.append(SQLAction(symbol=symbol))

    # Applied in order to the space-joined pieces: '##' pieces continue the
    # previous piece, and '.', '-' and '!' are re-attached to their neighbours.
    rejoin_rules = ((" ##", ""), (" .", "."), (" - ", "-"), (" !", "!"))

    def find_and_add_copy_from_text(substr: str):
        """Finds `substr` among the utterance wordpieces and copies that span.

        Returns True if a contiguous run of wordpieces re-joins to the
        lower-cased `substr` (copy actions are appended), else False.
        """
        wordpieces = example.model_input.utterance_wordpieces
        target = substr.lower()
        for start in range(len(wordpieces)):
            for end in range(start + 1, len(wordpieces) + 1):
                span = wordpieces[start:end]
                composed = " ".join(piece.wordpiece for piece in span)
                for old, new in rejoin_rules:
                    composed = composed.replace(old, new)
                if composed == target:
                    for piece in span:
                        example.gold_sql_query.actions.append(
                            SQLAction(utterance_copy=piece))
                    return True
        return False

    if anonymize:
        emit_symbol("value")
        return True

    item_str = str(item)

    # Add quotes if [1] there aren't quotes and [2] it's not numeric
    is_numeric = item_str.replace(".", "", 1).isdigit()
    if (not is_numeric and item_str.count('"') < 2
            and item_str.count("'") < 2):
        item_str = "'" + item_str + "'"

    utterance = example.model_input.original_utterance.lower()

    # Case 1: the value, exactly as written, occurs in the utterance.
    if item_str.lower() in utterance:
        copied = find_and_add_copy_from_text(item_str)

        # A successful copy has already appended its own actions, so the
        # symbol is only added when copying failed or when the token is
        # itself a valid generated token.
        if not copied or _is_valid_generated_token(item_str):
            emit_symbol(item_str)

        return copied or _is_valid_generated_token(item_str)

    quote_type = ""
    if item_str.startswith("'") and item_str.endswith("'"):
        quote_type = "'"
    elif item_str.startswith('"') and item_str.endswith('"'):
        quote_type = '"'

    if quote_type:
        unquoted = item_str[1:-1]

        # Case 2: the text between the quotes occurs in the utterance.
        if unquoted.lower() in utterance:
            emit_symbol(quote_type)

            copied = find_and_add_copy_from_text(unquoted)
            # NOTE(review): this fallback emits the full quoted `item_str`
            # between the quote symbols, while the '%' branch below emits only
            # the inner text -- confirm the asymmetry is intended.
            if not copied or _is_valid_generated_token(item_str):
                emit_symbol(item_str)

            emit_symbol(quote_type)

            return copied or _is_valid_generated_token(item_str)

        # Case 3: a '%...%' pattern whose inner text occurs in the utterance.
        if (item_str[1] == "%" and item_str[-2] == "%"
                and item_str[2:-2].lower() in utterance):
            pattern_body = item_str[2:-2]
            emit_symbol(quote_type)
            emit_symbol("%")

            copied = find_and_add_copy_from_text(pattern_body)
            if not copied or _is_valid_generated_token(item_str):
                emit_symbol(pattern_body)

            emit_symbol("%")
            emit_symbol(quote_type)

            return copied or _is_valid_generated_token(item_str)

    # Just add it as choice from the output vocabulary
    # NOTE(review): values containing "u s a" are rejected outright; the
    # reason is not evident from this function.
    if "u s a" in item_str:
        raise ValueError("WHAT????????")

    emit_symbol(item_str)

    # A value of 1 is used for things like LIMIT 1 when ordering.
    return _is_valid_generated_token(item_str)