def filter_not_in(self, row_expression_list: NestedList, column_name: str,
                  value_expression: NestedList) -> List[Dict[str, str]]:
     """
     Takes a list of rows, a column, and a string value and returns all the rows where the value
     in that column does not contain the given string.

     Both ``row_expression_list`` and ``value_expression`` are evaluated with
     ``self._handle_expression``. If the row expression evaluates to nothing, ``[]`` is returned
     without evaluating the value expression. The value expression must evaluate to a string, or
     to a non-empty list whose first element is a string; otherwise ``ExecutionError`` is raised.
     """
     row_list = self._handle_expression(row_expression_list)
     if not row_list:
         return []
     expression_evaluation = self._handle_expression(value_expression)
     if isinstance(expression_evaluation, list):
         # An empty list has no first element; let the type check below reject it instead of
         # crashing with an ``IndexError``.
         filter_value = expression_evaluation[0] if expression_evaluation else None
     else:
         filter_value = expression_evaluation
     if not isinstance(filter_value, str):
         raise ExecutionError(
             f"Unexpected filter value for filter_not_in: {value_expression}")
     # Assuming filter value has underscores for spaces. The cell values also have underscores
     # for spaces, so we do not need to replace them here.
     return [row for row in row_list if filter_value not in row[column_name]]
Example #2
0
 def _execute_object_filter(self, sub_expression: Union[str, List]) -> Set[Object]:
     """
     Object filtering functions should either be a string referring to all objects, or list which
     executes to a filtering operation.
     The elements should evaluate to one of the following:
         (object_filtering_function object_set)
         ((negate_filter object_filtering_function) object_set)
         all_objects
     """
     if sub_expression[0][0] == "negate_filter":
         initial_set = self._execute_object_filter(sub_expression[1])
         original_filter_name = sub_expression[0][1]
         # It is possible that the decoder has produced a sequence of nested negations. We deal
         # with that here.
         # TODO (pradeep): This is messy. Fix the type declaration so that we don't have to deal
         # with this.
         num_negations = 1
         while isinstance(original_filter_name, list) and \
               original_filter_name[0] == "negate_filter":
             # We have a sequence of "negate_filters"
             num_negations += 1
             original_filter_name = original_filter_name[1]
         if num_negations % 2 == 0:
             return initial_set
         try:
             original_filter = getattr(self, original_filter_name)
             return self.negate_filter(original_filter, initial_set)
         except AttributeError:
             logger.error("Function not found: %s", original_filter_name)
             raise ExecutionError("Function not found")
     elif sub_expression == "all_objects" or sub_expression[0] == "all_objects":
         return self._objects
     elif isinstance(sub_expression[0], str) and len(sub_expression) == 2:
         # These are functions like black, square, same_color etc.
         function = None
         try:
             function = getattr(self, sub_expression[0])
         except AttributeError:
             logger.error("Function not found: %s", sub_expression[0])
             raise ExecutionError("Function not found")
         arguments = sub_expression[1]
         if isinstance(arguments, list) and str(arguments[0]).startswith("member_") or \
             arguments == 'all_boxes' or arguments[0] == 'all_boxes':
             if sub_expression[0] != "object_in_box":
                 logger.error("Invalid object filter expression: %s", sub_expression)
                 raise ExecutionError("Invalid object filter expression")
             return function(self._execute_box_filter(arguments))
         else:
             return function(self._execute_object_filter(arguments))
     else:
         logger.error("Invalid object filter expression: %s", sub_expression)
         raise ExecutionError("Invalid object filter expression")
 def diff(self, first_row_expression_list: NestedList,
          second_row_expression_list: NestedList,
          column_name: str) -> float:
     """
     Evaluates two row expressions and returns the value under ``column_name`` in the first
     resulting row minus the value under that column in the second resulting row.

     Returns ``0.0`` when either expression evaluates to nothing. If an expression yields more
     than one row, a warning is logged and only its first row is used. Raises ``ExecutionError``
     when a cell value cannot be parsed as a float.
     """
     rows_a = self._handle_expression(first_row_expression_list)
     rows_b = self._handle_expression(second_row_expression_list)
     if not (rows_a and rows_b):
         return 0.0
     if len(rows_a) > 1:
         logger.warning(
             "diff got multiple rows for first argument. Taking the first one: "
             f"{first_row_expression_list}")
     if len(rows_b) > 1:
         logger.warning(
             "diff got multiple rows for second argument. Taking the first one: "
             f"{second_row_expression_list}")
     minuend_row, subtrahend_row = rows_a[0], rows_b[0]
     try:
         minuend = float(minuend_row[column_name])
         subtrahend = float(subtrahend_row[column_name])
     except ValueError:
         raise ExecutionError(f"Invalid column for diff: {column_name}")
     return minuend - subtrahend
Example #4
0
 def _execute_constant(sub_expression: str):
     """
     Acceptable constants are numbers or strings starting with `shape_` or `color_`
     """
     if not isinstance(sub_expression, str):
         logger.error("Invalid constant: %s", sub_expression)
         raise ExecutionError("Invalid constant")
     if str.isdigit(sub_expression):
         return int(sub_expression)
     elif sub_expression.startswith('color_'):
         return sub_expression.replace('color_', '')
     elif sub_expression.startswith('shape_'):
         return sub_expression.replace('shape_', '')
     else:
         logger.error("Invalid constant: %s", sub_expression)
         raise ExecutionError("Invalid constant")
 def _handle_constant(self,
                      constant: str) -> Union[List[Dict[str, str]], float]:
     if constant == "all_rows":
         return self._table_data
     try:
         return float(constant)
     except ValueError:
         raise ExecutionError(f"Cannot handle constant: {constant}")
 def _handle_constant(self, constant: str) -> Union[RowListType, str, float]:
     if constant == "all_rows":
         return self.table_data
     try:
         return float(constant)
     except ValueError:
         # The constant is not a number. Returning as-is if it is a string.
         if constant.startswith("string:"):
             return constant.replace("string:", "")
         raise ExecutionError(f"Cannot handle constant: {constant}")
 def date(year_string: str, month_string: str, day_string: str) -> Date:
     """
     Builds a ``Date`` from a year, a month and a day given as strings (in that order). Each
     argument is converted with ``int(str(...))``; a ``ValueError`` raised while converting or
     while constructing the ``Date`` is reported as an ``ExecutionError``.
     """
     try:
         year, month, day = [int(str(part))
                             for part in (year_string, month_string, day_string)]
         return Date(year, month, day)
     except ValueError:
         raise ExecutionError(f"Invalid date! Got {year_string}, {month_string}, {day_string}")
 def _handle_expression(self, expression_list):
     if isinstance(expression_list, list) and len(expression_list) == 1:
         expression = expression_list[0]
     else:
         expression = expression_list
     if isinstance(expression, list):
         # This is a function application.
         function_name = expression[0]
     else:
         # This is a constant (like "all_rows" or "2005")
         return self._handle_constant(expression)
     try:
         function = getattr(self, function_name)
         return function(*expression[1:])
     except AttributeError:
         raise ExecutionError(f"Function not found: {function_name}")
 def filter_date_not_equals(
         self, row_expression_list: NestedList, column_name: str,
         value_expression: NestedList) -> List[Dict[str, str]]:
     """
     Evaluates a row expression and a value expression, and keeps the rows whose date under
     ``column_name`` differs from the evaluated value. Returns ``[]`` when there are no rows, and
     raises ``ExecutionError`` when the value expression does not evaluate to a ``Date``.
     """
     rows = self._handle_expression(row_expression_list)
     if not rows:
         return []
     date_row_pairs = self._get_date_row_pairs_to_filter(rows, column_name)
     target_date = self._handle_expression(value_expression)
     if not isinstance(target_date, Date):
         raise ExecutionError(f"Invalid filter value: {value_expression}")
     return [row for cell_date, row in date_row_pairs if cell_date != target_date]
 def filter_number_lesser(
         self, row_expression_list: NestedList, column_name: str,
         value_expression: NestedList) -> List[Dict[str, str]]:
     """
     Evaluates a row expression and a numerical value expression, and keeps the rows whose number
     under ``column_name`` is strictly less than the evaluated value. Returns ``[]`` when there
     are no rows, and raises ``ExecutionError`` when the value is not a ``float``.
     """
     rows = self._handle_expression(row_expression_list)
     if not rows:
         return []
     number_row_pairs = self._get_number_row_pairs_to_filter(rows, column_name)
     threshold = self._handle_expression(value_expression)
     if not isinstance(threshold, float):
         raise ExecutionError(f"Invalid filter value: {value_expression}")
     return [row for number, row in number_row_pairs if number < threshold]
 def filter_date_greater(self,
                         row_expression_list: NestedList,
                         column_name: str,
                         value_expression: NestedList) -> RowListType:
     """
     Evaluates a row expression and a value expression, and keeps the rows whose date under
     ``column_name`` is strictly greater than the evaluated value. Returns ``[]`` when there are
     no rows, and raises ``ExecutionError`` when the value expression does not evaluate to a
     ``Date``.
     """
     rows = self._handle_expression(row_expression_list)
     if not rows:
         return []
     date_row_pairs = self._get_date_row_pairs_to_filter(rows, column_name)
     target_date = self._handle_expression(value_expression)
     if not isinstance(target_date, Date):
         raise ExecutionError(f"Invalid filter value: {value_expression}")
     return [row for cell_date, row in date_row_pairs if cell_date > target_date]
Example #12
0
    def _execute_box_filter(self, sub_expression: Union[str, List]) -> Set[Box]:
        """
        Box filtering functions either apply a filter on a set of boxes and return the filtered set,
        or return all the boxes.
        The elements should evaluate to one of the following:
        ``(box_filtering_function set_to_filter constant)`` or
        ``all_boxes``

        In the first kind of forms, the ``box_filtering_function`` also specifies the attribute
        being compared and the comparison operator. The attribute is of the objects contained in
        each box in the ``set_to_filter``.
        Example: ``(member_color_count_greater all_boxes 1)``
        filters all boxes by extracting the colors of the objects in each of them, and returns a
        subset of boxes from the original set where the number of colors of objects is greater than
        1.
        """
        # TODO(pradeep): We may want to change the order of arguments here to make decoding easier.
        if sub_expression[0].startswith('member_'):
            function_name_parts = sub_expression[0].split("_")
            if len(function_name_parts) == 3:
                attribute_type = function_name_parts[1]
                comparison_op = function_name_parts[2]
            elif function_name_parts[2] == "count":
                attribute_type = "_".join(function_name_parts[1:3])
                comparison_op = "_".join(function_name_parts[3:])
            else:
                attribute_type = function_name_parts[1]
                comparison_op = "_".join(function_name_parts[2:])
            set_to_filter = self._execute_box_filter(sub_expression[1])
            return_set = set()
            if comparison_op in ["same", "different"]:
                # We don't need a target attribute for these functions, and the "comparison" is done
                # on sets.
                comparison_function = self._set_unary_operators[comparison_op]
                for box in set_to_filter:
                    returned_attribute: Set[str] = self._attribute_functions[attribute_type](box.objects)
                    if comparison_function(returned_attribute):
                        return_set.add(box)
            else:
                target_attribute = self._execute_constant(sub_expression[-1])
                is_set_operation = comparison_op in ["all_equals", "any_equals", "none_equals"]
                # These are comparisons like equals, greater etc, and we need a target attribute
                # which we first evaluate here. Then, the returned attribute (if it is a singleton
                # set or an integer), is compared against the target attribute.
                for box in set_to_filter:
                    if is_set_operation:
                        returned_attribute = self._attribute_functions[attribute_type](box.objects)
                        box_wanted = self._set_binary_operators[comparison_op](returned_attribute,
                                                                               target_attribute)
                    else:
                        returned_count = self._count_functions[attribute_type](box.objects)
                        box_wanted = self._number_operators[comparison_op](returned_count,
                                                                           target_attribute)
                    if box_wanted:
                        return_set.add(box)
            return return_set
        elif sub_expression == 'all_boxes' or sub_expression[0] == 'all_boxes':
            return self._boxes
        else:
            logger.error("Invalid box filter expression: %s", sub_expression)
            raise ExecutionError("Unknown box filter expression")
Example #13
0
    def _execute_assertion(self, sub_expression: List) -> bool:
        """
        Assertion functions are boolean functions. They are of two types:
        1) Exists functions: They take one argument, a set and check whether it is not empty.
        Syntax: ``(exists_function function_returning_entities)``
        Example: ``(object_exists (black (top all_objects)))`` ("There is a black object at the top
        of a tower.")
        2) Other assert functions: They take two arguments, which evaluate to strings or integers,
        and compare them. The first element in the input list should be the assertion function name,
        the second a function returning entities, and the last element should be a constant. The
        assertion function should specify the entity type, the attribute being compared, and a
        comparison operator, in that order separated by underscores. The following are the expected
        values:
            Entity types: ``object``, ``box``
            Attributes being compared: ``color``, ``shape``, ``count``, ``color_count``,
            ``shape_count``
            Comparison operator:
                Applicable to sets: ``all_equals``, ``any_equals``, ``none_equals``, ``same``,
                ``different``
                Applicable to counts: ``equals``, ``not_equals``, ``lesser``, ``lesser_equals``,
                ``greater``, ``greater_equals``
        Syntax: ``(assertion_function function_returning_entities constant)``
        Example: ``(box_count_equals (member_shape_equals all_boxes shape_square) 2)``
        ("There are exactly two boxes with only squares in them")

        Note that the first kind is a special case of the second where the attribute type is
        ``count``, comparison operator is ``greater_equals`` and the constant is ``1``.
        """
        # TODO(pradeep): We may want to change the order of arguments here to make decoding easier.
        assert isinstance(sub_expression, list), "Invalid assertion expression: %s" % sub_expression
        if len(sub_expression) == 1 and isinstance(sub_expression[0], list):
            return self._execute_assertion(sub_expression[0])
        is_assert_function = sub_expression[0].startswith("object_") or \
        sub_expression[0].startswith("box_")
        assert isinstance(sub_expression[0], str) and is_assert_function,\
               "Invalid assertion function: %s" % (sub_expression[0])
        # Example: box_count_not_equals, entities being evaluated are boxes, the relevant attibute
        # is their count, and the function will return true if the attribute is not equal to the
        # target.
        function_name_parts = sub_expression[0].split('_')
        entity_type = function_name_parts[0]
        target_attribute = None
        if len(function_name_parts) == 2 and function_name_parts[1] == "exists":
            attribute_type = "count"
            comparison_op = "greater_equals"
            target_attribute = 1
        else:
            target_attribute = self._execute_constant(sub_expression[2])
            # If the length of ``function_name_parts`` is 3, getting the attribute and comparison
            # operator is easy. However, if it is greater than 3, we need to determine where the
            # attribute function stops and where the comparison operator begins.
            if len(function_name_parts) == 3:
                # These are cases like ``object_color_equals``, ``box_count_greater`` etc.
                attribute_type = function_name_parts[1]
                comparison_op = function_name_parts[2]
            elif function_name_parts[2] == 'count':
                # These are cases like ``object_color_count_equals``,
                # ``object_shape_count_greater_equals`` etc.
                attribute_type = "_".join(function_name_parts[1:3])
                comparison_op = "_".join(function_name_parts[3:])
            else:
                # These are cases like ``box_count_greater_equals``, ``object_shape_not_equals``
                # etc.
                attribute_type = function_name_parts[1]
                comparison_op = "_".join(function_name_parts[2:])

        entity_expression = sub_expression[1]
        returned_count = None
        returned_attribute = None
        if entity_type == "box":
            # You can only count boxes. The other attributes do not apply.
            returned_count = self._count(self._execute_box_filter(entity_expression))
        elif "count" in attribute_type:
            # We're counting objects, colors or shapes.
            count_function = self._count_functions[attribute_type]
            returned_count = count_function(self._execute_object_filter(entity_expression))
        else:
            # We're getting colors or shapes from objects.
            attribute_function = self._attribute_functions[attribute_type]
            returned_attribute = attribute_function(self._execute_object_filter(entity_expression))

        if comparison_op in ["all_equals", "any_equals", "none_equals"]:
            set_comparison = self._set_binary_operators[comparison_op]
            if returned_attribute is None:
                logger.error("Invalid assertion function: %s", sub_expression[0])
                raise ExecutionError("Invalid assertion function")
            return set_comparison(returned_attribute, target_attribute)
        else:
            number_comparison = self._number_operators[comparison_op]
            if returned_count is None:
                logger.error("Invalid assertion function: %s", sub_expression[0])
                raise ExecutionError("Invalid assertion function")
            return number_comparison(returned_count, target_attribute)