Example #1
0
    def test_unary_ops_resolve_correctly(self):
        """Check that ``UnaryOpType`` resolves only against <X,X>-shaped types."""
        op_type = UnaryOpType()

        # A basic type is not a function type, so it cannot be resolved.
        assert op_type.resolve(ROW_TYPE) is None
        # <CELL,ROW> has different argument and return types, so it cannot be resolved either.
        assert op_type.resolve(ComplexType(CELL_TYPE, ROW_TYPE)) is None

        # An ANY_TYPE on one side is filled in from the concrete type on the other side.
        assert op_type.resolve(ComplexType(ANY_TYPE, ROW_TYPE)) == UnaryOpType(ROW_TYPE)
        assert op_type.resolve(ComplexType(CELL_TYPE, ANY_TYPE)) == UnaryOpType(CELL_TYPE)

        # The operand may itself be a complex type: <<CELL,ROW>,<CELL,ROW>>.
        nested_type = ComplexType(
            ComplexType(CELL_TYPE, ROW_TYPE),
            ComplexType(CELL_TYPE, ROW_TYPE),
        )
        expected_nested = UnaryOpType(ComplexType(CELL_TYPE, ROW_TYPE))
        assert op_type.resolve(nested_type) == expected_nested
Example #2
0
 def test_get_valid_actions_with_any_type(self):
     """Check that ANY_TYPE in a signature is expanded to every available basic type.

     This simulates an intermediate step of producing actions for a placeholder
     type. Real function signatures are not expected to contain ANY_TYPE; a
     ``PlaceholderType`` should be used for that instead.
     """
     type_r, type_d, type_e = (NamedBasicType(label) for label in ("R", "D", "E"))
     # ``sample_function`` is aliased to F, whose signature is <?,r>.
     name_mapping = {'sample_function': 'F'}
     type_signatures = {'F': ComplexType(ANY_TYPE, type_r)}
     basic_types = {type_r, type_d, type_e}

     valid_actions = types.get_valid_actions(name_mapping, type_signatures, basic_types)

     expected_actions = {
         "<d,r>": ["<d,r> -> sample_function"],
         "<e,r>": ["<e,r> -> sample_function"],
         "<r,r>": ["<r,r> -> sample_function"],
         "r": ["r -> [<d,r>, d]", "r -> [<e,r>, e]", "r -> [<r,r>, r]"],
         "@start@": ["@start@ -> d", "@start@ -> e", "@start@ -> r"],
     }
     assert len(valid_actions) == len(expected_actions)
     for nonterminal, productions in expected_actions.items():
         assert valid_actions[nonterminal] == productions
Example #3
0
    def test_binary_ops_resolve_correctly(self):
        """Check that ``BinaryOpType`` resolves only against consistent <X,<X,X>> shapes."""
        op_type = BinaryOpType()

        # Not a two-argument function: a basic type, and a complex type returning a basic type.
        assert op_type.resolve(CELL_TYPE) is None
        assert op_type.resolve(ComplexType(CELL_TYPE, ROW_TYPE)) is None

        # Shapes whose argument and return types cannot be unified into a single type.
        unresolvable_types = [
            ComplexType(ANY_TYPE, ComplexType(CELL_TYPE, ROW_TYPE)),
            ComplexType(ROW_TYPE, ComplexType(CELL_TYPE, ANY_TYPE)),
            ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, CELL_TYPE)),
        ]
        for candidate in unresolvable_types:
            assert op_type.resolve(candidate) is None

        # Shapes whose ANY_TYPE slots can all be filled with ROW_TYPE.
        resolvable_types = [
            ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, ROW_TYPE)),
            ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, ANY_TYPE)),
            ComplexType(ANY_TYPE, ComplexType(ROW_TYPE, ANY_TYPE)),
        ]
        for candidate in resolvable_types:
            assert op_type.resolve(candidate) == BinaryOpType(ROW_TYPE)
    def __init__(self, syntax: str) -> None:
        """Set up basic types, function signatures and name mappings for one LF syntax.

        All supported syntaxes share the same function types. They differ only in
        which attribute predicates are registered with ``attr_function_type``.

        Args:
            syntax: one of ``"quarel_friction"``, ``"quarel_v1"``,
                ``"quarel_v1_attr_entities"`` or ``"quarel_friction_attr_entities"``.

        Raises:
            ValueError: if ``syntax`` is not a supported specification. ``ValueError``
                is a subclass of ``Exception``, so existing handlers still catch it.
        """
        self.name_mapper = NameMapper()

        num_type = NamedBasicType("NUM")
        attr_type = NamedBasicType("ATTR")
        rdir_type = NamedBasicType("RDIR")
        world_type = NamedBasicType("WORLD")
        var_type = NamedBasicType("VAR")

        self.basic_types = {
            num_type, attr_type, rdir_type, world_type, var_type
        }

        # Attribute predicates registered for each syntax, listed in registration order.
        # NOTE(review): the order is kept identical to the previous per-branch code
        # because NameMapper may derive aliases from registration order.
        friction_attributes = ("friction", "smoothness", "speed", "heat", "distance")
        attributes_by_syntax = {
            "quarel_friction": friction_attributes,
            # TODO: Remove this? The *_attr_entities syntaxes register only a placeholder.
            "quarel_v1_attr_entities": ("placeholder",),
            "quarel_friction_attr_entities": ("placeholder",),
            "quarel_v1": friction_attributes + (
                "acceleration", "amountSweat", "apparentSize", "breakability",
                "brightness", "exerciseIntensity", "flexibility", "gravity",
                "loudness", "mass", "strength", "thickness", "time", "weight",
            ),
        }
        if syntax not in attributes_by_syntax:
            raise ValueError(f"Unknown LF syntax specification: {syntax}")

        # attribute: <RDIR, <WORLD, ATTR>>
        attr_function_type = ComplexType(
            rdir_type, ComplexType(world_type, attr_type))
        # and: <ATTR, <ATTR, ATTR>>
        and_function_type = ComplexType(attr_type,
                                        ComplexType(attr_type, attr_type))
        # infer: <ATTR, <ATTR, <ATTR, NUM>>>
        infer_function_type = ComplexType(
            attr_type,
            ComplexType(attr_type, ComplexType(attr_type, num_type)))

        self.name_mapper.map_name_with_signature("infer", infer_function_type)
        for attribute_name in attributes_by_syntax[syntax]:
            self.name_mapper.map_name_with_signature(attribute_name,
                                                     attr_function_type)

        # For simplicity we treat "high" and "low" as directions as well
        self.name_mapper.map_name_with_signature("high", rdir_type)
        self.name_mapper.map_name_with_signature("low", rdir_type)
        self.name_mapper.map_name_with_signature("and", and_function_type)

        # Number of arguments each curried function type takes (see the signatures above).
        self.curried_functions = {
            attr_function_type: 2,
            infer_function_type: 3,
            and_function_type: 2,
        }

        self.name_mapper.map_name_with_signature("higher", rdir_type)
        self.name_mapper.map_name_with_signature("lower", rdir_type)

        self.name_mapper.map_name_with_signature("world1", world_type)
        self.name_mapper.map_name_with_signature("world2", world_type)

        # Hack to expose types
        self.world_type = world_type
        self.attr_function_type = attr_function_type
        self.var_type = var_type

        self.starting_types = {num_type}
    @overrides
    def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
        """Return this type alone in a list.

        This type contains no ANY_TYPE, so ``basic_types`` is not consulted.
        """
        unchanged = [self]
        return unchanged


# NLTK gives every constant ``EntityType`` by default. Domains whose logical forms contain
# constants of several types can pass ``constant_type_prefixes`` to the ``World``
# constructor. The NLVR language defined here only has numeric constants, so numbers
# simply keep the default ``EntityType``.
NUM_TYPE = EntityType()

# Named basic types.
BOX_TYPE = NamedBasicType("BOX")
OBJECT_TYPE = NamedBasicType("OBJECT")
COLOR_TYPE = NamedBasicType("COLOR")
SHAPE_TYPE = NamedBasicType("SHAPE")

# Object filters: <OBJECT, OBJECT>, plus the negation wrapper over that shape.
OBJECT_FILTER_TYPE = ComplexType(OBJECT_TYPE, OBJECT_TYPE)
NEGATE_FILTER_TYPE = NegateFilterType(
    ComplexType(OBJECT_TYPE, OBJECT_TYPE),
    ComplexType(OBJECT_TYPE, OBJECT_TYPE),
)
# <BOX, OBJECT>
BOX_MEMBERSHIP_TYPE = ComplexType(BOX_TYPE, OBJECT_TYPE)

# Box filters of the form <BOX, <COLOR | SHAPE | NUM, BOX>>.
BOX_COLOR_FILTER_TYPE = ComplexType(
    BOX_TYPE,
    ComplexType(COLOR_TYPE, BOX_TYPE),
)
BOX_SHAPE_FILTER_TYPE = ComplexType(
    BOX_TYPE,
    ComplexType(SHAPE_TYPE, BOX_TYPE),
)
BOX_COUNT_FILTER_TYPE = ComplexType(
    BOX_TYPE,
    ComplexType(NUM_TYPE, BOX_TYPE),
)
# Box filter keeping boxes where a specified attribute is the same (or different).
BOX_ATTRIBUTE_SAME_FILTER_TYPE = ComplexType(BOX_TYPE, BOX_TYPE)


# Assertions of the form <OBJECT | BOX, <COLOR | SHAPE | NUM, TRUTH>>.
ASSERT_COLOR_TYPE = ComplexType(
    OBJECT_TYPE,
    ComplexType(COLOR_TYPE, TRUTH_TYPE),
)
ASSERT_SHAPE_TYPE = ComplexType(
    OBJECT_TYPE,
    ComplexType(SHAPE_TYPE, TRUTH_TYPE),
)
ASSERT_BOX_COUNT_TYPE = ComplexType(
    BOX_TYPE,
    ComplexType(NUM_TYPE, TRUTH_TYPE),
)
ASSERT_OBJECT_COUNT_TYPE = ComplexType(
    OBJECT_TYPE,
    ComplexType(NUM_TYPE, TRUTH_TYPE),
)
 def get_application_type(self, argument_type: Type) -> Type:
     """Return ``argument_type`` with its two components swapped: <a,b> becomes <b,a>."""
     head = argument_type.first
     tail = argument_type.second
     return ComplexType(tail, head)
    @overrides
    def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
        """Expand an ANY_TYPE first component into one ``CountType`` per basic type.

        If the first component is not ANY_TYPE, this type is returned unchanged
        in a single-element list.
        """
        is_concrete = self.first != ANY_TYPE
        return [self] if is_concrete else list(map(CountType, basic_types))


# Named basic types.
CELL_TYPE = NamedBasicType("CELL")
PART_TYPE = NamedBasicType("PART")
ROW_TYPE = NamedBasicType("ROW")
DATE_TYPE = NamedBasicType("DATE")
NUMBER_TYPE = NamedBasicType("NUMBER")

BASIC_TYPES = {CELL_TYPE, PART_TYPE, ROW_TYPE, DATE_TYPE, NUMBER_TYPE}

# Column functions such as fb:row.row.year: <CELL, ROW>.
COLUMN_TYPE = ComplexType(CELL_TYPE, ROW_TYPE)
# Cell accessors fb:cell.cell.part, fb:cell.cell.date and fb:cell.cell.number.
PART_TO_CELL_TYPE = ComplexType(PART_TYPE, CELL_TYPE)
DATE_TO_CELL_TYPE = ComplexType(DATE_TYPE, CELL_TYPE)
NUM_TO_CELL_TYPE = ComplexType(NUMBER_TYPE, CELL_TYPE)
# The "number" function: <NUMBER, NUMBER>.
NUMBER_FUNCTION_TYPE = ComplexType(NUMBER_TYPE, NUMBER_TYPE)
# The "date" function takes three numbers, e.g. (date 1982 -1 -1):
# <NUMBER, <NUMBER, <NUMBER, DATE>>>.
DATE_FUNCTION_TYPE = ComplexType(
    NUMBER_TYPE,
    ComplexType(NUMBER_TYPE, ComplexType(NUMBER_TYPE, DATE_TYPE)),
)
# Unary numerical operations such as max, min, >, < and sum. The first variant allows
# DATE or NUMBER substitutions; the second is fixed to NUMBER.
UNARY_DATE_NUM_OP_TYPE = UnaryOpType(
    allowed_substitutions={DATE_TYPE, NUMBER_TYPE},
    signature='<nd,nd>',
)
UNARY_NUM_OP_TYPE = ComplexType(NUMBER_TYPE, NUMBER_TYPE)
# Column types that match several concrete column types.
GENERIC_COLUMN_TYPE = MultiMatchNamedBasicType(
    "GCOLUMN",
    [STRING_COLUMN_TYPE, DATE_COLUMN_TYPE, NUMBER_COLUMN_TYPE],
)
COMPARABLE_COLUMN_TYPE = MultiMatchNamedBasicType(
    "CCOLUMN",
    [NUMBER_COLUMN_TYPE, DATE_COLUMN_TYPE],
)

NUMBER_TYPE = NamedBasicType("NUMBER")
DATE_TYPE = NamedBasicType("DATE")
STRING_TYPE = NamedBasicType("STRING")

# Only the generic column type is included up front. The specific column types are
# meant to be added to this set only when the table contains columns of that type.
BASIC_TYPES = {ROW_TYPE, GENERIC_COLUMN_TYPE, NUMBER_TYPE, DATE_TYPE, STRING_TYPE}
STARTING_TYPES = {NUMBER_TYPE, DATE_TYPE, STRING_TYPE}

# --- Complex types ---

# "select" and "mode": take the value under a column for a set of rows.
SELECT_TYPE = ComplexType(
    ROW_TYPE,
    ComplexType(GENERIC_COLUMN_TYPE, STRING_TYPE),
)

# Row filters driven by a column:
#  * "same_as" accepts any column. It selects the rows whose value under that column
#    equals the given row's value under the same column.
#  * "argmax" and "argmin" accept only comparable (date or number) columns.
# In this language the compared values always come from a column lookup; LambdaDCS
# instead applies a lambda function to the rows to get them.
ROW_FILTER_WITH_GENERIC_COLUMN = ComplexType(
    ROW_TYPE,
    ComplexType(GENERIC_COLUMN_TYPE, ROW_TYPE),
)
ROW_FILTER_WITH_COMPARABLE_COLUMN = ComplexType(
    ROW_TYPE,
    ComplexType(COMPARABLE_COLUMN_TYPE, ROW_TYPE),
)

# "filter_number_greater", "filter_number_equals", and so on.
ROW_FILTER_WITH_COLUMN_AND_NUMBER = ComplexType(
    ROW_TYPE,
    ComplexType(
        NUMBER_COLUMN_TYPE,
        ComplexType(NUMBER_TYPE, ROW_TYPE),
    ),
)
Example #9
0
    def test_reverse_resolves_correctly(self):
        """Check REVERSE_TYPE resolution: it fills ANY_TYPE slots so the result is <<a,b>,<b,a>>."""
        def _resolve(first, second):
            # Resolve REVERSE_TYPE against the complex type <first, second>.
            return REVERSE_TYPE.resolve(ComplexType(first, second))

        # A basic type cannot be reversed.
        assert REVERSE_TYPE.resolve(CELL_TYPE) is None

        # <?, <CELL,ROW>>  ->  <<ROW,CELL>, <CELL,ROW>>
        assert _resolve(ANY_TYPE, ComplexType(CELL_TYPE, ROW_TYPE)) == ReverseType(
            ComplexType(ROW_TYPE, CELL_TYPE), ComplexType(CELL_TYPE, ROW_TYPE))

        # <<ROW,?>, <CELL,?>>  ->  <<ROW,CELL>, <CELL,ROW>>
        assert _resolve(ComplexType(ROW_TYPE, ANY_TYPE),
                        ComplexType(CELL_TYPE, ANY_TYPE)) == ReverseType(
            ComplexType(ROW_TYPE, CELL_TYPE), ComplexType(CELL_TYPE, ROW_TYPE))

        # <<ROW,?>, ?>  ->  <<ROW,?>, <?,ROW>>
        assert _resolve(ComplexType(ROW_TYPE, ANY_TYPE), ANY_TYPE) == ReverseType(
            ComplexType(ROW_TYPE, ANY_TYPE), ComplexType(ANY_TYPE, ROW_TYPE))

        # <<ROW,?>, <?,CELL>> would need ROW reversed onto CELL, so it must not resolve.
        assert _resolve(ComplexType(ROW_TYPE, ANY_TYPE),
                        ComplexType(ANY_TYPE, CELL_TYPE)) is None
Example #10
0
    NamedBasicType, ComplexType, NameMapper)

# --- Basic types ---
# PAS is a predicate-argument structure.
PAS_TYPE = NamedBasicType("PAS")
RELATION_TYPE = NamedBasicType("RELATION")

NUMBER_TYPE = NamedBasicType("NUMBER")
DATE_TYPE = NamedBasicType("DATE")
STRING_TYPE = NamedBasicType("STRING")

BASIC_TYPES = {PAS_TYPE, RELATION_TYPE, NUMBER_TYPE, DATE_TYPE, STRING_TYPE}
STARTING_TYPES = {NUMBER_TYPE, DATE_TYPE, STRING_TYPE}

# --- Complex types ---
# "select", "mode" and "extract_entity": <PAS, <RELATION, STRING>>.
SELECT_TYPE = ComplexType(
    PAS_TYPE,
    ComplexType(RELATION_TYPE, STRING_TYPE),
)

# "argmax" and "argmin" filter structures by a relation: <PAS, <RELATION, PAS>>.
PAS_FILTER_WITH_RELATION = ComplexType(
    PAS_TYPE,
    ComplexType(RELATION_TYPE, PAS_TYPE),
)

# "filter_number_greater", "filter_number_equals", and so on.
PAS_FILTER_WITH_RELATION_AND_NUMBER = ComplexType(
    PAS_TYPE,
    ComplexType(RELATION_TYPE, ComplexType(NUMBER_TYPE, PAS_TYPE)),
)

# "filter_date_greater", "filter_date_equals", and so on.
PAS_FILTER_WITH_RELATION_AND_DATE = ComplexType(
    PAS_TYPE,
    ComplexType(RELATION_TYPE, ComplexType(DATE_TYPE, PAS_TYPE)),
)

# "filter_in" and "filter_not_in"
PAS_FILTER_WITH_RELATION_AND_STRING = ComplexType(