def test_unary_ops_resolve_correctly(self):
    """Exercise ``UnaryOpType.resolve`` on basic, mismatched, partially unknown and nested types."""
    unary_op = UnaryOpType()

    # A bare basic type must not resolve.
    assert unary_op.resolve(ROW_TYPE) is None

    # Neither must a complex type whose argument and return types differ.
    mismatched_type = ComplexType(CELL_TYPE, ROW_TYPE)
    assert unary_op.resolve(mismatched_type) is None

    # An ANY_TYPE on either side is expected to be filled in from the other side.
    resolved_from_return = unary_op.resolve(ComplexType(ANY_TYPE, ROW_TYPE))
    assert resolved_from_return == UnaryOpType(ROW_TYPE)
    resolved_from_argument = unary_op.resolve(ComplexType(CELL_TYPE, ANY_TYPE))
    assert resolved_from_argument == UnaryOpType(CELL_TYPE)

    # The same complex type on both sides is expected to give a unary op over that complex type.
    nested_type = ComplexType(
        ComplexType(CELL_TYPE, ROW_TYPE),
        ComplexType(CELL_TYPE, ROW_TYPE),
    )
    resolved_nested = unary_op.resolve(nested_type)
    assert resolved_nested == UnaryOpType(ComplexType(CELL_TYPE, ROW_TYPE))
def test_get_valid_actions_with_any_type(self):
    """
    Check that ``types.get_valid_actions`` expands ANY_TYPE into every basic type.

    This simulates an intermediate step while getting actions for a placeholder type. Function
    types are not expected to be declared with ANY_TYPE directly; a ``PlaceholderType`` should
    be used for that instead.
    """
    r_type = NamedBasicType("R")
    d_type = NamedBasicType("D")
    e_type = NamedBasicType("E")
    names_to_aliases = {'sample_function': 'F'}
    # The single function has signature <?,r>.
    signatures = {'F': ComplexType(ANY_TYPE, r_type)}
    all_basic_types = {r_type, d_type, e_type}

    valid_actions = types.get_valid_actions(names_to_aliases, signatures, all_basic_types)

    assert len(valid_actions) == 5
    # Checked in this order: the three substituted function types, then "r", then the start symbol.
    expected_actions = {
        "<d,r>": ["<d,r> -> sample_function"],
        "<e,r>": ["<e,r> -> sample_function"],
        "<r,r>": ["<r,r> -> sample_function"],
        "r": [
            "r -> [<d,r>, d]",
            "r -> [<e,r>, e]",
            "r -> [<r,r>, r]",
        ],
        "@start@": [
            "@start@ -> d",
            "@start@ -> e",
            "@start@ -> r",
        ],
    }
    for non_terminal, productions in expected_actions.items():
        assert valid_actions[non_terminal] == productions
def test_binary_ops_resolve_correctly(self):
    """Exercise ``BinaryOpType.resolve`` on non-binary, incompatible and partially unknown types."""
    binary_op = BinaryOpType()

    # Neither a basic type nor a complex type that returns a basic type may resolve.
    assert binary_op.resolve(CELL_TYPE) is None
    assert binary_op.resolve(ComplexType(CELL_TYPE, ROW_TYPE)) is None

    # Two-argument shapes whose known components disagree with each other must be rejected.
    assert binary_op.resolve(
        ComplexType(ANY_TYPE, ComplexType(CELL_TYPE, ROW_TYPE))) is None
    assert binary_op.resolve(
        ComplexType(ROW_TYPE, ComplexType(CELL_TYPE, ANY_TYPE))) is None
    assert binary_op.resolve(
        ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, CELL_TYPE))) is None

    # When the known components agree, every ANY_TYPE is expected to be filled in with them.
    assert binary_op.resolve(
        ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, ROW_TYPE))) == BinaryOpType(ROW_TYPE)
    assert binary_op.resolve(
        ComplexType(ROW_TYPE, ComplexType(ANY_TYPE, ANY_TYPE))) == BinaryOpType(ROW_TYPE)
    assert binary_op.resolve(
        ComplexType(ANY_TYPE, ComplexType(ROW_TYPE, ANY_TYPE))) == BinaryOpType(ROW_TYPE)
def __init__(self, syntax: str) -> None:
    """
    Register the names and type signatures for one logical-form syntax.

    Every supported syntax shares the same basic types, the same three function types and the
    same fixed names ("infer", "high", "low", "and", "higher", "lower", "world1", "world2");
    the only difference between them is which attribute names are mapped to the attribute
    function type.

    Parameters
    ----------
    syntax : ``str``
        One of ``"quarel_friction"``, ``"quarel_v1_attr_entities"``,
        ``"quarel_friction_attr_entities"`` or ``"quarel_v1"``.

    Raises
    ------
    ValueError
        If ``syntax`` is not one of the values above.
    """
    self.name_mapper = NameMapper()

    num_type = NamedBasicType("NUM")
    attr_type = NamedBasicType("ATTR")
    rdir_type = NamedBasicType("RDIR")
    world_type = NamedBasicType("WORLD")
    var_type = NamedBasicType("VAR")

    self.basic_types = {num_type, attr_type, rdir_type, world_type, var_type}

    # The attribute names are the only syntax-specific part of the declaration.
    friction_attributes = ("friction", "smoothness", "speed", "heat", "distance")
    if syntax == "quarel_friction":
        attribute_names = friction_attributes
    elif syntax in ("quarel_v1_attr_entities", "quarel_friction_attr_entities"):
        # TODO: Remove this?
        attribute_names = ("placeholder",)
    elif syntax == "quarel_v1":
        attribute_names = friction_attributes + (
            "acceleration",
            "amountSweat",
            "apparentSize",
            "breakability",
            "brightness",
            "exerciseIntensity",
            "flexibility",
            "gravity",
            "loudness",
            "mass",
            "strength",
            "thickness",
            "time",
            "weight",
        )
    else:
        raise ValueError(f"Unknown LF syntax specification: {syntax}")

    # attributes: <RDIR, <WORLD, ATTR>>
    attr_function_type = ComplexType(rdir_type, ComplexType(world_type, attr_type))
    # and: <ATTR, <ATTR, ATTR>>
    and_function_type = ComplexType(attr_type, ComplexType(attr_type, attr_type))
    # infer: <ATTR, <ATTR, <ATTR, NUM>>>
    infer_function_type = ComplexType(
        attr_type, ComplexType(attr_type, ComplexType(attr_type, num_type)))

    # The registration order below ("infer", attributes, "high", "low", "and", then the shared
    # names) is kept fixed in case the mapper's result depends on the order of registration.
    self.name_mapper.map_name_with_signature("infer", infer_function_type)
    for attribute_name in attribute_names:
        self.name_mapper.map_name_with_signature(attribute_name, attr_function_type)

    # For simplicity we treat "high" and "low" as directions as well
    self.name_mapper.map_name_with_signature("high", rdir_type)
    self.name_mapper.map_name_with_signature("low", rdir_type)
    self.name_mapper.map_name_with_signature("and", and_function_type)

    # Each count equals the number of arguments in the corresponding signature above
    # (presumably the arity used when currying -- confirm against the consumer of this dict).
    self.curried_functions = {
        attr_function_type: 2,
        infer_function_type: 3,
        and_function_type: 2,
    }

    self.name_mapper.map_name_with_signature("higher", rdir_type)
    self.name_mapper.map_name_with_signature("lower", rdir_type)

    self.name_mapper.map_name_with_signature("world1", world_type)
    self.name_mapper.map_name_with_signature("world2", world_type)

    # Hack to expose types
    self.world_type = world_type
    self.attr_function_type = attr_function_type
    self.var_type = var_type

    self.starting_types = {num_type}
@overrides
def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
    """
    Return the concrete types this type stands for once ANY_TYPE is substituted.

    ``basic_types`` is accepted for interface compatibility but is not used here: the result is
    always a one-element list containing this type itself.
    """
    # There's no ANY_TYPE in here, so we don't need to do any substitution.
    return [self]


# All constants default to ``EntityType`` in NLTK. For domains where constants of different types
# appear in the logical forms, we have a way of specifying ``constant_type_prefixes`` and passing
# them to the constructor of ``World``. However, in the NLVR language we defined, we see constants
# of just one type, number. So we let them default to ``EntityType``.
NUM_TYPE = EntityType()

# Named basic types of the language.
BOX_TYPE = NamedBasicType("BOX")
OBJECT_TYPE = NamedBasicType("OBJECT")
COLOR_TYPE = NamedBasicType("COLOR")
SHAPE_TYPE = NamedBasicType("SHAPE")

# Function types built over OBJECT. NEGATE_FILTER_TYPE is built from two separately constructed
# OBJECT -> OBJECT types, each with the same shape as OBJECT_FILTER_TYPE.
OBJECT_FILTER_TYPE = ComplexType(OBJECT_TYPE, OBJECT_TYPE)
NEGATE_FILTER_TYPE = NegateFilterType(ComplexType(OBJECT_TYPE, OBJECT_TYPE),
                                      ComplexType(OBJECT_TYPE, OBJECT_TYPE))

# Function types whose first component is BOX. The filter types take a second component
# (COLOR, SHAPE or NUM) and end in BOX.
BOX_MEMBERSHIP_TYPE = ComplexType(BOX_TYPE, OBJECT_TYPE)
BOX_COLOR_FILTER_TYPE = ComplexType(BOX_TYPE, ComplexType(COLOR_TYPE, BOX_TYPE))
BOX_SHAPE_FILTER_TYPE = ComplexType(BOX_TYPE, ComplexType(SHAPE_TYPE, BOX_TYPE))
BOX_COUNT_FILTER_TYPE = ComplexType(BOX_TYPE, ComplexType(NUM_TYPE, BOX_TYPE))
# This box filter returns boxes where a specified attribute is same or different
BOX_ATTRIBUTE_SAME_FILTER_TYPE = ComplexType(BOX_TYPE, BOX_TYPE)

# Assertion types: each has an OBJECT or BOX first component, a second component to compare
# against, and ends in TRUTH_TYPE.
ASSERT_COLOR_TYPE = ComplexType(OBJECT_TYPE, ComplexType(COLOR_TYPE, TRUTH_TYPE))
ASSERT_SHAPE_TYPE = ComplexType(OBJECT_TYPE, ComplexType(SHAPE_TYPE, TRUTH_TYPE))
ASSERT_BOX_COUNT_TYPE = ComplexType(BOX_TYPE, ComplexType(NUM_TYPE, TRUTH_TYPE))
ASSERT_OBJECT_COUNT_TYPE = ComplexType(OBJECT_TYPE, ComplexType(NUM_TYPE, TRUTH_TYPE))
def get_application_type(self, argument_type: Type) -> Type:
    """
    Return the type obtained by applying this type to ``argument_type``.

    The result is a ``ComplexType`` whose components are those of ``argument_type`` in swapped
    order: its ``second`` becomes the new first component and its ``first`` the new second.
    """
    new_first = argument_type.second
    new_second = argument_type.first
    return ComplexType(new_first, new_second)
@overrides
def substitute_any_type(self, basic_types: Set[BasicType]) -> List[Type]:
    """
    Return the concrete types this type stands for once ANY_TYPE is substituted.

    If the first component is already something other than ANY_TYPE, the type is returned
    unchanged in a one-element list. Otherwise one ``CountType`` is produced per entry of
    ``basic_types``.
    """
    if self.first != ANY_TYPE:
        substituted = [self]
    else:
        substituted = [CountType(candidate) for candidate in basic_types]
    return substituted


# Basic types.
CELL_TYPE = NamedBasicType("CELL")
PART_TYPE = NamedBasicType("PART")
ROW_TYPE = NamedBasicType("ROW")
DATE_TYPE = NamedBasicType("DATE")
NUMBER_TYPE = NamedBasicType("NUMBER")
BASIC_TYPES = {CELL_TYPE, PART_TYPE, ROW_TYPE, DATE_TYPE, NUMBER_TYPE}

# Column functions such as fb:row.row.year.
COLUMN_TYPE = ComplexType(CELL_TYPE, ROW_TYPE)

# fb:cell.cell.part
PART_TO_CELL_TYPE = ComplexType(PART_TYPE, CELL_TYPE)

# fb:cell.cell.date
DATE_TO_CELL_TYPE = ComplexType(DATE_TYPE, CELL_TYPE)

# fb:cell.cell.number
NUM_TO_CELL_TYPE = ComplexType(NUMBER_TYPE, CELL_TYPE)

# number
NUMBER_FUNCTION_TYPE = ComplexType(NUMBER_TYPE, NUMBER_TYPE)

# date (Signature: <n,<n,<n,d>>>; Example: (date 1982 -1 -1))
DATE_FUNCTION_TYPE = ComplexType(
    NUMBER_TYPE,
    ComplexType(NUMBER_TYPE, ComplexType(NUMBER_TYPE, DATE_TYPE)),
)

# Unary numerical operations (max, min, >, <, sum etc.). The first accepts either dates or
# numbers as its substitution; the second is fixed to numbers.
UNARY_DATE_NUM_OP_TYPE = UnaryOpType(allowed_substitutions={DATE_TYPE, NUMBER_TYPE},
                                     signature='<nd,nd>')
UNARY_NUM_OP_TYPE = ComplexType(NUMBER_TYPE, NUMBER_TYPE)
# Column types that stand for several specific column types at once: the generic one lists the
# string, date and number column types, the comparable one only the number and date column types.
GENERIC_COLUMN_TYPE = MultiMatchNamedBasicType("GCOLUMN", [STRING_COLUMN_TYPE,
                                                           DATE_COLUMN_TYPE,
                                                           NUMBER_COLUMN_TYPE])
COMPARABLE_COLUMN_TYPE = MultiMatchNamedBasicType("CCOLUMN", [NUMBER_COLUMN_TYPE, DATE_COLUMN_TYPE])

# Basic value types.
NUMBER_TYPE = NamedBasicType("NUMBER")
DATE_TYPE = NamedBasicType("DATE")
STRING_TYPE = NamedBasicType("STRING")

# We start with just the generic column type, and add the specific column types to the set only if
# we see the corresponding types in the table.
BASIC_TYPES = {ROW_TYPE, GENERIC_COLUMN_TYPE, NUMBER_TYPE, DATE_TYPE, STRING_TYPE}
# NOTE(review): presumably the types a complete logical form may evaluate to -- confirm against
# the code that consumes STARTING_TYPES.
STARTING_TYPES = {NUMBER_TYPE, DATE_TYPE, STRING_TYPE}

# Complex types
# Type for selecting the value in a column in a set of rows. "select" and "mode" functions.
SELECT_TYPE = ComplexType(ROW_TYPE, ComplexType(GENERIC_COLUMN_TYPE, STRING_TYPE))

# Type for filtering rows given a column. "argmax", "argmin" and "same_as" (select all rows with the
# same value under the given column as the given row does under the given column). While "same_as"
# takes any column, "argmax" and "argmin" take only comparable columns (i.e. dates or numbers).
# Note that the values used for comparison in "argmax" and "argmin" can only come from column
# lookups in this language. In LambdaDCS, there's a lambda function that is applied to the rows to
# get the values, but here, we simply have a column name.
ROW_FILTER_WITH_GENERIC_COLUMN = ComplexType(ROW_TYPE, ComplexType(GENERIC_COLUMN_TYPE, ROW_TYPE))
ROW_FILTER_WITH_COMPARABLE_COLUMN = ComplexType(ROW_TYPE, ComplexType(COMPARABLE_COLUMN_TYPE, ROW_TYPE))

# "filter_number_greater", "filter_number_equals" etc.
# Takes rows, then a number column, then a number, and ends in rows.
ROW_FILTER_WITH_COLUMN_AND_NUMBER = ComplexType(ROW_TYPE,
                                                ComplexType(NUMBER_COLUMN_TYPE,
                                                            ComplexType(NUMBER_TYPE, ROW_TYPE)))
def test_reverse_resolves_correctly(self):
    """Exercise ``REVERSE_TYPE.resolve`` on a basic type and on partially specified complex types."""
    # A basic type cannot be resolved against.
    assert REVERSE_TYPE.resolve(CELL_TYPE) is None

    # <?,<e,r>> is expected to resolve to <<r,e>,<e,r>>.
    unknown_argument = ComplexType(ANY_TYPE, ComplexType(CELL_TYPE, ROW_TYPE))
    resolved = REVERSE_TYPE.resolve(unknown_argument)
    assert resolved == ReverseType(ComplexType(ROW_TYPE, CELL_TYPE),
                                   ComplexType(CELL_TYPE, ROW_TYPE))

    # <<r,?>,<e,?>> is expected to resolve to <<r,e>,<e,r>> as well.
    unknown_on_both_sides = ComplexType(ComplexType(ROW_TYPE, ANY_TYPE),
                                        ComplexType(CELL_TYPE, ANY_TYPE))
    resolved = REVERSE_TYPE.resolve(unknown_on_both_sides)
    assert resolved == ReverseType(ComplexType(ROW_TYPE, CELL_TYPE),
                                   ComplexType(CELL_TYPE, ROW_TYPE))

    # <<r,?>,?> only pins down one component, so the expected result is <<r,?>,<?,r>>.
    unknown_return = ComplexType(ComplexType(ROW_TYPE, ANY_TYPE), ANY_TYPE)
    resolved = REVERSE_TYPE.resolve(unknown_return)
    assert resolved == ReverseType(ComplexType(ROW_TYPE, ANY_TYPE),
                                   ComplexType(ANY_TYPE, ROW_TYPE))

    # <<r,?>,<?,e>> is inconsistent with a reversal, so resolution is expected to give None.
    inconsistent = ComplexType(ComplexType(ROW_TYPE, ANY_TYPE),
                               ComplexType(ANY_TYPE, CELL_TYPE))
    resolved = REVERSE_TYPE.resolve(inconsistent)
    assert resolved is None
NamedBasicType, ComplexType, NameMapper) # Basic types PAS_TYPE = NamedBasicType("PAS") # Predicate Argument Structure RELATION_TYPE = NamedBasicType("RELATION") NUMBER_TYPE = NamedBasicType("NUMBER") DATE_TYPE = NamedBasicType("DATE") STRING_TYPE = NamedBasicType("STRING") BASIC_TYPES = {PAS_TYPE, RELATION_TYPE, NUMBER_TYPE, DATE_TYPE, STRING_TYPE} STARTING_TYPES = {NUMBER_TYPE, DATE_TYPE, STRING_TYPE} # Complex types # Type for selecting the value in a column in a set of rows. "select", "mode", and "extract_entity" functions. SELECT_TYPE = ComplexType(PAS_TYPE, ComplexType(RELATION_TYPE, STRING_TYPE)) # Type for filtering structures given a relation. "argmax", "argmin" PAS_FILTER_WITH_RELATION = ComplexType(PAS_TYPE, ComplexType(RELATION_TYPE, PAS_TYPE)) # "filter_number_greater", "filter_number_equals" etc. PAS_FILTER_WITH_RELATION_AND_NUMBER = ComplexType( PAS_TYPE, ComplexType(RELATION_TYPE, ComplexType(NUMBER_TYPE, PAS_TYPE))) # "filter_date_greater", "filter_date_equals" etc. PAS_FILTER_WITH_RELATION_AND_DATE = ComplexType( PAS_TYPE, ComplexType(RELATION_TYPE, ComplexType(DATE_TYPE, PAS_TYPE))) # "filter_in" and "filter_not_in" PAS_FILTER_WITH_RELATION_AND_STRING = ComplexType(