예제 #1
0
    def test_all_possible_actions(self):  # pylint: disable=no-self-use
        """
        Check that an ``AtisWorld`` built from a single utterance reports exactly the
        expected list of production rules, in exactly this order.
        """
        utterance = ("give me all flights from boston to "
                     "philadelphia next week arriving after lunch")

        # The comparison below is a list equality, so both content and order matter.
        expected_actions = [
            'agg -> [agg_func, "(", col_ref, ")"]',
            'agg_func -> ["COUNT"]',
            'agg_func -> ["MAX"]',
            'agg_func -> ["MIN"]',
            'agg_results -> ["(", "SELECT", distinct, agg, "FROM", table_name, '
            'where_clause, ")"]',
            'agg_results -> ["SELECT", distinct, agg, "FROM", table_name, where_clause]',
            'biexpr -> [col_ref, "LIKE", string]',
            'biexpr -> [col_ref, binaryop, value]',
            'biexpr -> [value, binaryop, value]',
            'binaryop -> ["*"]',
            'binaryop -> ["+"]',
            'binaryop -> ["-"]',
            'binaryop -> ["/"]',
            'binaryop -> ["<"]',
            'binaryop -> ["<="]',
            'binaryop -> ["="]',
            'binaryop -> [">"]',
            'binaryop -> [">="]',
            'binaryop -> ["IS"]',
            'boolean -> ["false"]',
            'boolean -> ["true"]',
            'col_ref -> ["*"]',
            'col_ref -> ["aircraft", ".", "aircraft_code"]',
            'col_ref -> ["aircraft", ".", "aircraft_description"]',
            'col_ref -> ["aircraft", ".", "basic_type"]',
            'col_ref -> ["aircraft", ".", "manufacturer"]',
            'col_ref -> ["aircraft", ".", "pressurized"]',
            'col_ref -> ["aircraft", ".", "propulsion"]',
            'col_ref -> ["aircraft", ".", "wide_body"]',
            'col_ref -> ["airline", ".", "airline_code"]',
            'col_ref -> ["airline", ".", "airline_name"]',
            'col_ref -> ["airport", ".", "airport_code"]',
            'col_ref -> ["airport", ".", "airport_location"]',
            'col_ref -> ["airport", ".", "airport_name"]',
            'col_ref -> ["airport", ".", "country_name"]',
            'col_ref -> ["airport", ".", "minimum_connect_time"]',
            'col_ref -> ["airport", ".", "state_code"]',
            'col_ref -> ["airport", ".", "time_zone_code"]',
            'col_ref -> ["airport_service", ".", "airport_code"]',
            'col_ref -> ["airport_service", ".", "city_code"]',
            'col_ref -> ["airport_service", ".", "direction"]',
            'col_ref -> ["airport_service", ".", "miles_distant"]',
            'col_ref -> ["airport_service", ".", "minutes_distant"]',
            'col_ref -> ["city", ".", "city_code"]',
            'col_ref -> ["city", ".", "city_name"]',
            'col_ref -> ["city", ".", "country_name"]',
            'col_ref -> ["city", ".", "state_code"]',
            'col_ref -> ["city", ".", "time_zone_code"]',
            'col_ref -> ["class_of_service", ".", "booking_class"]',
            'col_ref -> ["class_of_service", ".", "class_description"]',
            'col_ref -> ["class_of_service", ".", "rank"]',
            'col_ref -> ["date_day", ".", "day_name"]',
            'col_ref -> ["date_day", ".", "day_number"]',
            'col_ref -> ["date_day", ".", "month_number"]',
            'col_ref -> ["date_day", ".", "year"]',
            'col_ref -> ["days", ".", "day_name"]',
            'col_ref -> ["days", ".", "days_code"]',
            'col_ref -> ["equipment_sequence", ".", "aircraft_code"]',
            'col_ref -> ["equipment_sequence", ".", "aircraft_code_sequence"]',
            'col_ref -> ["fare", ".", "fare_airline"]',
            'col_ref -> ["fare", ".", "fare_basis_code"]',
            'col_ref -> ["fare", ".", "fare_id"]',
            'col_ref -> ["fare", ".", "from_airport"]',
            'col_ref -> ["fare", ".", "one_direction_cost"]',
            'col_ref -> ["fare", ".", "restriction_code"]',
            'col_ref -> ["fare", ".", "round_trip_cost"]',
            'col_ref -> ["fare", ".", "round_trip_required"]',
            'col_ref -> ["fare", ".", "to_airport"]',
            'col_ref -> ["fare_basis", ".", "basis_days"]',
            'col_ref -> ["fare_basis", ".", "booking_class"]',
            'col_ref -> ["fare_basis", ".", "class_type"]',
            'col_ref -> ["fare_basis", ".", "discounted"]',
            'col_ref -> ["fare_basis", ".", "economy"]',
            'col_ref -> ["fare_basis", ".", "fare_basis_code"]',
            'col_ref -> ["fare_basis", ".", "night"]',
            'col_ref -> ["fare_basis", ".", "premium"]',
            'col_ref -> ["fare_basis", ".", "season"]',
            'col_ref -> ["flight", ".", "aircraft_code_sequence"]',
            'col_ref -> ["flight", ".", "airline_code"]',
            'col_ref -> ["flight", ".", "airline_flight"]',
            'col_ref -> ["flight", ".", "arrival_time"]',
            'col_ref -> ["flight", ".", "connections"]',
            'col_ref -> ["flight", ".", "departure_time"]',
            'col_ref -> ["flight", ".", "dual_carrier"]',
            'col_ref -> ["flight", ".", "flight_days"]',
            'col_ref -> ["flight", ".", "flight_id"]',
            'col_ref -> ["flight", ".", "flight_number"]',
            'col_ref -> ["flight", ".", "from_airport"]',
            'col_ref -> ["flight", ".", "meal_code"]',
            'col_ref -> ["flight", ".", "stops"]',
            'col_ref -> ["flight", ".", "time_elapsed"]',
            'col_ref -> ["flight", ".", "to_airport"]',
            'col_ref -> ["flight_fare", ".", "fare_id"]',
            'col_ref -> ["flight_fare", ".", "flight_id"]',
            'col_ref -> ["flight_leg", ".", "flight_id"]',
            'col_ref -> ["flight_leg", ".", "leg_flight"]',
            'col_ref -> ["flight_leg", ".", "leg_number"]',
            'col_ref -> ["flight_stop", ".", "arrival_airline"]',
            'col_ref -> ["flight_stop", ".", "arrival_flight_number"]',
            'col_ref -> ["flight_stop", ".", "arrival_time"]',
            'col_ref -> ["flight_stop", ".", "departure_airline"]',
            'col_ref -> ["flight_stop", ".", "departure_flight_number"]',
            'col_ref -> ["flight_stop", ".", "departure_time"]',
            'col_ref -> ["flight_stop", ".", "flight_id"]',
            'col_ref -> ["flight_stop", ".", "stop_airport"]',
            'col_ref -> ["flight_stop", ".", "stop_days"]',
            'col_ref -> ["flight_stop", ".", "stop_number"]',
            'col_ref -> ["flight_stop", ".", "stop_time"]',
            'col_ref -> ["food_service", ".", "compartment"]',
            'col_ref -> ["food_service", ".", "meal_code"]',
            'col_ref -> ["food_service", ".", "meal_description"]',
            'col_ref -> ["food_service", ".", "meal_number"]',
            'col_ref -> ["ground_service", ".", "airport_code"]',
            'col_ref -> ["ground_service", ".", "city_code"]',
            'col_ref -> ["ground_service", ".", "ground_fare"]',
            'col_ref -> ["ground_service", ".", "transport_type"]',
            'col_ref -> ["month", ".", "month_name"]',
            'col_ref -> ["month", ".", "month_number"]',
            'col_ref -> ["restriction", ".", "advance_purchase"]',
            'col_ref -> ["restriction", ".", "application"]',
            'col_ref -> ["restriction", ".", "maximum_stay"]',
            'col_ref -> ["restriction", ".", "minimum_stay"]',
            'col_ref -> ["restriction", ".", "no_discounts"]',
            'col_ref -> ["restriction", ".", "restriction_code"]',
            'col_ref -> ["restriction", ".", "saturday_stay_required"]',
            'col_ref -> ["restriction", ".", "stopovers"]',
            'col_ref -> ["state", ".", "country_name"]',
            'col_ref -> ["state", ".", "state_code"]',
            'col_ref -> ["state", ".", "state_name"]',
            'col_refs -> [col_ref, ",", col_refs]',
            'col_refs -> [col_ref]',
            'condition -> [biexpr]',
            'condition -> [in_clause]',
            'condition -> [ternaryexpr]',
            'conditions -> ["(", conditions, ")", conj, conditions]',
            'conditions -> ["(", conditions, ")"]',
            'conditions -> ["NOT", conditions]',
            'conditions -> [condition, conj, "(", conditions, ")"]',
            'conditions -> [condition, conj, conditions]',
            'conditions -> [condition]',
            'conj -> ["AND"]',
            'conj -> ["OR"]',
            'distinct -> [""]',
            'distinct -> ["DISTINCT"]',
            'in_clause -> [col_ref, "IN", query]',
            'number -> ["0"]',
            'number -> ["1"]',
            'number -> ["1200"]',
            'number -> ["1400"]',
            'number -> ["1800"]',
            'pos_value -> ["ALL", query]',
            'pos_value -> ["ANY", query]',
            'pos_value -> ["NULL"]',
            'pos_value -> [agg_results]',
            'pos_value -> [boolean]',
            'pos_value -> [col_ref]',
            'pos_value -> [number]',
            'pos_value -> [string]',
            'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, '
            'where_clause, ")"]',
            'query -> ["SELECT", distinct, select_results, "FROM", table_refs, '
            'where_clause]',
            'select_results -> [agg]',
            'select_results -> [col_refs]',
            'statement -> [query, ";"]',
            'string -> ["\'BBOS\'"]',
            'string -> ["\'BOS\'"]',
            'string -> ["\'BOSTON\'"]',
            'string -> ["\'LUNCH\'"]',
            'string -> ["\'PHILADELPHIA\'"]',
            'string -> ["\'PHL\'"]',
            'string -> ["\'PPHL\'"]',
            'table_name -> ["aircraft"]',
            'table_name -> ["airline"]',
            'table_name -> ["airport"]',
            'table_name -> ["airport_service"]',
            'table_name -> ["city"]',
            'table_name -> ["class_of_service"]',
            'table_name -> ["date_day"]',
            'table_name -> ["days"]',
            'table_name -> ["equipment_sequence"]',
            'table_name -> ["fare"]',
            'table_name -> ["fare_basis"]',
            'table_name -> ["flight"]',
            'table_name -> ["flight_fare"]',
            'table_name -> ["flight_leg"]',
            'table_name -> ["flight_stop"]',
            'table_name -> ["food_service"]',
            'table_name -> ["ground_service"]',
            'table_name -> ["month"]',
            'table_name -> ["restriction"]',
            'table_name -> ["state"]',
            'table_refs -> [table_name, ",", table_refs]',
            'table_refs -> [table_name]',
            'ternaryexpr -> [col_ref, "BETWEEN", value, "AND", value]',
            'ternaryexpr -> [col_ref, "NOT", "BETWEEN", value, "AND", value]',
            'value -> ["NOT", pos_value]',
            'value -> [pos_value]',
            'where_clause -> ["WHERE", "(", conditions, ")"]',
            'where_clause -> ["WHERE", conditions]',
        ]

        world = AtisWorld([utterance])
        assert world.all_possible_actions() == expected_actions
예제 #2
0
    def text_to_instance(
            self,  # type: ignore
            utterances: List[str],
            sql_query: str = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Build an ``Instance`` for the current (last) utterance of an interaction.

        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
        sql_query: ``str``, optional
            The SQL query, given as label during training or validation.

        Returns
        -------
        An ``Instance`` with the fields ``utterance``, ``actions``, ``world`` and
        ``linking_scores``, plus ``target_action_sequence`` when ``sql_query`` is given and
        yields an action sequence.  ``None`` is returned when there is no current utterance
        (empty list or empty last utterance), or when ``sql_query`` is given but no action
        sequence could be obtained for it.
        """
        # An interaction without any turn has no current utterance; skip it the same way
        # an empty utterance is skipped instead of failing on ``utterances[-1]``.
        if not utterances:
            return None

        utterance = utterances[-1]
        if not utterance:
            return None

        world = AtisWorld(utterances=utterances,
                          database_directory=self._database_directory)

        action_sequence: List[str] = []
        if sql_query:
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                # Leave ``action_sequence`` empty; the instance is dropped further down.
                logger.debug('Parsing error')

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            lhs, _ = production_rule.split(' ->')
            # Every rule is flagged as global unless it expands ``number`` or ``string``.
            is_global_rule = lhs not in ('number', 'string')
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join(
                token for token in production_rule.split(' ') if token != 'ws')
            production_rule_fields.append(
                ProductionRuleField(production_rule, is_global_rule))

        action_field = ListField(production_rule_fields)
        # Maps each rule string to its position in ``action_field``.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        world_field = MetadataField(world)
        fields = {
            'utterance': utterance_field,
            'actions': action_field,
            'world': world_field,
            'linking_scores': ArrayField(world.linking_scores)
        }

        if sql_query:
            if not action_sequence:
                # If we are given a SQL query, but we are unable to parse it, then we will skip it.
                return None

            index_fields: List[Field] = [
                IndexField(action_map[production_rule], action_field)
                for production_rule in action_sequence
            ]
            # The target is a list that wraps the single action sequence.
            action_sequence_field: List[Field] = [ListField(index_fields)]
            fields['target_action_sequence'] = ListField(
                action_sequence_field)

        return Instance(fields)
예제 #3
0
    def text_to_instance(
            self,  # type: ignore
            utterances: List[str],
            sql_query_labels: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Build an ``Instance`` for the current (last) utterance of an interaction.

        The caller's ``utterances`` list is never modified; when turns are concatenated
        the merged utterance only replaces the last element of a local copy.

        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
        sql_query_labels: ``List[str]``, optional
            The SQL queries that are given as labels during training or validation.

        Returns
        -------
        An ``Instance`` with the fields ``utterance``, ``actions``, ``world`` and
        ``linking_scores``.  When ``sql_query_labels`` is given, ``sql_queries`` is added,
        and ``target_action_sequence`` is added if ``self._keep_if_unparseable`` is not set.
        ``None`` is returned when there is no current utterance (empty list or empty last
        utterance), or when labels are given, no action sequence could be obtained for
        them, and ``self._keep_if_unparseable`` is not set.
        """
        # An interaction without any turn has no current utterance; skip it the same way
        # an empty utterance is skipped instead of failing on ``utterances[-1]``.
        if not utterances:
            return None

        if self._num_turns_to_concatenate:
            merged_utterance = f' {END_OF_UTTERANCE_TOKEN} '.join(
                utterances[-self._num_turns_to_concatenate:])
            # Rebind to a fresh list rather than assigning into the argument, so the
            # caller's list keeps its original last utterance.
            utterances = [*utterances[:-1], merged_utterance]

        utterance = utterances[-1]
        if not utterance:
            return None

        world = AtisWorld(utterances=utterances)

        action_sequence: List[str] = []
        if sql_query_labels:
            # If there are multiple sql queries given as labels, we use the shortest
            # one for training.
            sql_query = min(sql_query_labels, key=len)
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                # Leave ``action_sequence`` empty; handled below.
                logger.debug('Parsing error')

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            nonterminal, _ = production_rule.split(' ->')
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join(
                token for token in production_rule.split(' ') if token != 'ws')
            production_rule_fields.append(
                ProductionRuleField(production_rule,
                                    self._is_global_rule(nonterminal)))

        action_field = ListField(production_rule_fields)
        # Maps each rule string to its position in ``action_field``.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        world_field = MetadataField(world)
        fields = {
            'utterance': utterance_field,
            'actions': action_field,
            'world': world_field,
            'linking_scores': ArrayField(world.linking_scores)
        }

        if sql_query_labels is not None:
            fields['sql_queries'] = MetadataField(sql_query_labels)
            # NOTE(review): with ``_keep_if_unparseable`` set, no target action sequence is
            # attached even when the query parses; presumably intentional so that every
            # kept instance has the same fields -- confirm.
            if not self._keep_if_unparseable:
                if not action_sequence:
                    # If we are given a SQL query, but we are unable to parse it, and we do not
                    # specify explicitly to keep it, then we will skip it.
                    return None

                index_fields: List[Field] = [
                    IndexField(action_map[production_rule], action_field)
                    for production_rule in action_sequence
                ]
                fields['target_action_sequence'] = ListField(index_fields)

        return Instance(fields)
예제 #4
0
파일: atis.py 프로젝트: apmoore1/allennlp
    def text_to_instance(self,  # type: ignore
                         utterances: List[str],
                         sql_query_labels: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Build an ``Instance`` for the current (last) utterance of an interaction.

        The caller's ``utterances`` list is never modified; when turns are concatenated
        the merged utterance only replaces the last element of a local copy.

        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
        sql_query_labels: ``List[str]``, optional
            The SQL queries that are given as labels during training or validation.

        Returns
        -------
        An ``Instance`` with the fields ``utterance``, ``actions``, ``world`` and
        ``linking_scores``.  When ``sql_query_labels`` is given, ``sql_queries`` and
        ``target_action_sequence`` are added; the latter holds one index per action, or a
        single ``-1`` index when no action sequence could be obtained and
        ``self._keep_if_unparseable`` is set.  ``None`` is returned when there is no current
        utterance (empty list or empty last utterance), or when labels are given, no action
        sequence could be obtained, and ``self._keep_if_unparseable`` is not set.
        """
        # An interaction without any turn has no current utterance; skip it the same way
        # an empty utterance is skipped instead of failing on ``utterances[-1]``.
        if not utterances:
            return None

        if self._num_turns_to_concatenate:
            merged_utterance = f' {END_OF_UTTERANCE_TOKEN} '.join(utterances[-self._num_turns_to_concatenate:])
            # Rebind to a fresh list rather than assigning into the argument, so the
            # caller's list keeps its original last utterance.
            utterances = [*utterances[:-1], merged_utterance]

        utterance = utterances[-1]
        if not utterance:
            return None

        world = AtisWorld(utterances=utterances)

        action_sequence: List[str] = []
        if sql_query_labels:
            # If there are multiple sql queries given as labels, we use the shortest
            # one for training.
            sql_query = min(sql_query_labels, key=len)
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                # ``action_sequence`` stays empty, which marks the query as unparseable below.
                logger.debug('Parsing error')

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            nonterminal, _ = production_rule.split(' ->')
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join(token for token in production_rule.split(' ') if token != 'ws')
            production_rule_fields.append(ProductionRuleField(production_rule, self._is_global_rule(nonterminal)))

        action_field = ListField(production_rule_fields)
        # Maps each rule string to its position in ``action_field``.
        action_map = {action.rule: i # type: ignore
                      for i, action in enumerate(action_field.field_list)}
        world_field = MetadataField(world)
        fields = {'utterance' : utterance_field,
                  'actions' : action_field,
                  'world' : world_field,
                  'linking_scores' : ArrayField(world.linking_scores)}

        if sql_query_labels is not None:
            fields['sql_queries'] = MetadataField(sql_query_labels)
            index_fields: List[Field]
            if action_sequence:
                index_fields = [IndexField(action_map[production_rule], action_field)
                                for production_rule in action_sequence]
            elif self._keep_if_unparseable:
                # Unparseable but explicitly kept: use a single ``-1`` index as the target.
                index_fields = [IndexField(-1, action_field)]
            else:
                # If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly
                # to keep it, then we will skip it.
                return None
            fields['target_action_sequence'] = ListField(index_fields)

        return Instance(fields)