def validate(self) -> None:
    """Checks that the targeted column can be dropped from its table.

    Raises:
      error.SpannerError: If no model is registered for the table, if the
        column is not one of the model's fields, or if any index column schema
        row references this table/column pair.
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
        raise error.SpannerError('Table {} does not exist'.format(self._table))

    if self._column not in table_model.fields:
        raise error.SpannerError('Column {} does not exist on {}'.format(
            self._column, self._table))

    # Refuse to drop a column that still participates in an index: count the
    # index-column schema rows that point at this exact table/column pair.
    index_references = index_column.IndexColumnSchema.count(
        None,
        condition.equal_to('column_name', self._column),
        condition.equal_to('table_name', self._table),
    )
    if index_references > 0:
        raise error.SpannerError('Column {} is indexed'.format(self._column))
def test_or(self):
    """OR-ing two single-condition groups renders parenthesized alternatives."""
    first = condition.equal_to("int_", 1)
    second = condition.equal_to("int_", 2)
    query = self.select(condition.or_([first], [second]))

    self.assertEndsWith(
        query.sql(), "((table.int_ = @int_0) OR (table.int_ = @int_1))")
    self.assertEqual(query.parameters(), {"int_0": 1, "int_1": 2})
    # Both placeholders bind INT64 values, so both share the Integer type.
    self.assertEqual(
        query.types(),
        {
            "int_0": field.Integer.grpc_type(),
            "int_1": field.Integer.grpc_type(),
        },
    )
def test_includes_subconditions_query(self):
    """A subcondition on an included relation is ANDed onto the join clause."""
    select_query = self.includes('parents', condition.equal_to('key', 'value'))
    expected_sql = (
        'WHERE SmallTestModel.key = RelationshipTestModel.parent_key '
        'AND SmallTestModel.key = @key0')
    # expected_sql is a literal SQL fragment, not a pattern. assertRegex would
    # treat every '.' as a wildcard, so check for the exact substring instead.
    self.assertIn(expected_sql, select_query.sql())
def count_equal(cls,
                transaction: Optional[spanner_transaction.Transaction] = None,
                **constraints: Any) -> int:
    """Counts the objects in Spanner that satisfy keyword-argument constraints.

    Shorthand over `count`: every keyword argument becomes one condition on
    the column named by the keyword.

    Args:
      transaction: The existing transaction to use, or None to start a new
        transaction
      **constraints: Column name mapped to a value. A list value produces a
        `condition.in_list` condition; any other value produces a
        `condition.equal_to` condition.

    Returns:
      The integer result of the COUNT query
    """
    conditions = [
        condition.in_list(column, value)
        if isinstance(value, list) else condition.equal_to(column, value)
        for column, value in constraints.items()
    ]
    return cls.count(transaction, *conditions)
def test_includes_subconditions_query(self):
    """A subcondition on an included relation is ANDed onto the join clause."""
    select_query = self.includes("parents", condition.equal_to("key", "value"))
    expected_sql = (
        "WHERE SmallTestModel.key = RelationshipTestModel.parent_key "
        "AND SmallTestModel.key = @key0")
    # expected_sql is a literal SQL fragment, not a pattern. assertRegex would
    # treat every '.' as a wildcard, so check for the exact substring instead.
    self.assertIn(expected_sql, select_query.sql())
def where_equal(
    cls,
    transaction: Optional[spanner_transaction.Transaction] = None,
    **constraints: Any
) -> List["ModelObject"]:
    """Fetches the objects in Spanner that satisfy keyword-argument constraints.

    Args:
      transaction: The existing transaction to use, or None to start a new
        transaction
      **constraints: Column name mapped to a value. A list value is matched
        with `condition.in_list`; any other value is matched with
        `condition.equal_to`.

    Returns:
      A list containing all requested objects that exist in the table (can be
      an empty list)
    """
    conditions = []
    for column, value in constraints.items():
        # Select the condition factory first, then apply it uniformly.
        make_condition = (
            condition.in_list if isinstance(value, list) else condition.equal_to)
        conditions.append(make_condition(column, value))
    return cls.where(transaction, *conditions)
def test_includes_error_on_invalid_subconditions(self, column, value, relation,
                                                 foreign_key_relation):
    """Invalid subconditions on an included relation raise ValidationError."""
    with self.assertRaises(error.ValidationError):
        # Build the subcondition inside the context so an early failure is
        # still attributed to the expected error.
        subcondition = condition.equal_to(column, value)
        self.includes(relation, subcondition, foreign_key_relation)
def test_query_combines_properly(self):
    """WHERE, ORDER BY and LIMIT are emitted in SQL order.

    LIMIT is passed before ORDER BY here, yet the rendered SQL puts ORDER BY
    first.
    """
    conditions = (
        condition.equal_to('int_', 5),
        condition.not_equal_to('string_array', ['foo', 'bar']),
        condition.limit(2),
        condition.order_by(('string', condition.OrderType.DESC)),
    )
    select_query = self.select(*conditions)
    self.assertEndsWith(
        select_query.sql(),
        'WHERE table.int_ = @int_0 AND table.string_array != '
        '@string_array1 ORDER BY table.string DESC LIMIT @limit2')
def test_query_combines_properly(self):
    """Mixed WHERE/ORDER BY/LIMIT conditions render in SQL clause order."""
    where_int = condition.equal_to("int_", 5)
    where_array = condition.not_equal_to("string_array", ["foo", "bar"])
    limit = condition.limit(2)
    ordering = condition.order_by(("string", condition.OrderType.DESC))
    query = self.select(where_int, where_array, limit, ordering)

    wanted_suffix = (
        "WHERE table.int_ = @int_0 AND table.string_array != "
        "@string_array1 ORDER BY table.string DESC LIMIT @limit2")
    self.assertEndsWith(query.sql(), wanted_suffix)
def indexes(cls) -> Dict[str, Dict[str, Any]]:
    """Builds `index.Index` objects from the index and index-column schemas.

    Returns:
      A mapping of table name to a mapping of index name to `index.Index`.
    """
    # Index-column rows are fetched in ordinal_position order so that each
    # index's key columns are appended in their declared order. Rows whose
    # ordinal_position is None are recorded as storing columns instead.
    column_rows = index_column.IndexColumnSchema.where(
        None,
        condition.equal_to("table_catalog", ""),
        condition.equal_to("table_schema", ""),
        condition.order_by(("ordinal_position", condition.OrderType.ASC)),
    )
    key_columns = collections.defaultdict(list)
    stored_columns = collections.defaultdict(list)
    for row in column_rows:
        index_key = (row.table_name, row.index_name)
        bucket = stored_columns if row.ordinal_position is None else key_columns
        bucket[index_key].append(row.column_name)

    index_rows = index_schema.IndexSchema.where(
        None,
        condition.equal_to("table_catalog", ""),
        condition.equal_to("table_schema", ""),
    )
    result = collections.defaultdict(dict)
    for row in index_rows:
        index_key = (row.table_name, row.index_name)
        built = index.Index(
            key_columns[index_key],
            parent=row.parent_table_name,
            null_filtered=row.is_null_filtered,
            unique=row.is_unique,
            storing_columns=stored_columns[index_key],
        )
        built.name = row.index_name
        result[row.table_name][row.index_name] = built
    return result
def test_includes_subcondition_result(self):
    """Results of a filtered include carry the child and its related parents."""
    query = self.includes('parents', condition.equal_to('key', 'value'))
    child_values, parent_values, rows = self.includes_result(related=2)
    first_result = query.process_results(rows)[0]

    self.assertLen(first_result.parents, 2)
    for attr_name, expected in child_values.items():
        self.assertEqual(getattr(first_result, attr_name), expected)
    for attr_name, expected in parent_values.items():
        self.assertEqual(getattr(first_result.parents[0], attr_name), expected)
def tables(cls) -> Dict[str, Dict[str, Any]]:
    """Builds per-table metadata from the column and table schemas.

    Returns:
      A mapping of table name to a dict with the keys 'parent_table' (the
      table schema's parent_table_name) and 'fields' (column name mapped to
      `field.Field`).
    """
    fields_by_table = collections.defaultdict(dict)
    column_rows = column.ColumnSchema.where(
        None,
        condition.equal_to('table_catalog', ''),
        condition.equal_to('table_schema', ''),
    )
    for row in column_rows:
        new_field = field.Field(row.field_type(), nullable=row.nullable())
        new_field.name = row.column_name
        new_field.position = row.ordinal_position
        fields_by_table[row.table_name][row.column_name] = new_field

    table_rows = table.TableSchema.where(
        None,
        condition.equal_to('table_catalog', ''),
        condition.equal_to('table_schema', ''),
    )
    result = collections.defaultdict(dict)
    for row in table_rows:
        name = row.table_name
        result[name].update(
            parent_table=row.parent_table_name,
            fields=fields_by_table[name],
        )
    return result
def test_includes_error_on_invalid_subconditions(self, column, value):
    """An invalid subcondition on the 'parent' include raises ValidationError."""
    with self.assertRaises(error.ValidationError):
        invalid_subcondition = condition.equal_to(column, value)
        self.includes('parent', invalid_subcondition)
class ConditionTest(
    spanner_emulator_testlib.TestCase,
    parameterized.TestCase,
):
    """Condition tests that execute their queries against the Spanner emulator."""

    def setUp(self):
        super().setUp()
        # Run the ORM migrations in the 'migrations_for_emulator_test'
        # directory next to this file before each test.
        self.run_orm_migrations(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                'migrations_for_emulator_test',
            ))

    @parameterized.parameters(
        (True, type_pb2.Type(code=type_pb2.BOOL)),
        (0, type_pb2.Type(code=type_pb2.INT64)),
        (0.0, type_pb2.Type(code=type_pb2.FLOAT64)),
        (
            datetime_helpers.DatetimeWithNanoseconds(2021, 1, 5),
            type_pb2.Type(code=type_pb2.TIMESTAMP),
        ),
        (datetime.datetime(2021, 1, 5), type_pb2.Type(code=type_pb2.TIMESTAMP)),
        (datetime.date(2021, 1, 5), type_pb2.Type(code=type_pb2.DATE)),
        (b'\x01', type_pb2.Type(code=type_pb2.BYTES)),
        ('foo', type_pb2.Type(code=type_pb2.STRING)),
        (decimal.Decimal('1.23'), type_pb2.Type(code=type_pb2.NUMERIC)),
        (
            (0, 1),
            type_pb2.Type(
                code=type_pb2.ARRAY,
                array_element_type=type_pb2.Type(code=type_pb2.INT64),
            ),
        ),
        (
            ['a', None, 'b'],
            type_pb2.Type(
                code=type_pb2.ARRAY,
                array_element_type=type_pb2.Type(code=type_pb2.STRING),
            ),
        ),
    )
    def test_param_from_value(self, value, expected_type):
        """Param.from_value infers the expected type for each Python value."""
        param = condition.Param.from_value(value)
        self.assertEqual(expected_type, param.type)
        # Test that the value and inferred type are compatible. This will raise an
        # exception if they're not.
        self.assertEmpty(
            models.SmallTestModel.where(
                condition.ArbitraryCondition(
                    '$param IS NULL',
                    dict(param=param),
                    segment=condition.Segment.WHERE,
                )))

    @parameterized.parameters(
        (None, 'Cannot infer type of None'),
        ((0, 'some-string'), 'elements of exactly one type'),
        ((0, 'some-string', None), 'elements of exactly one type'),
        (object(), 'Unknown type'),
    )
    def test_param_from_value_error(self, value, error_regex):
        """Values whose type cannot be inferred raise TypeError."""
        with self.assertRaisesRegex(TypeError, error_regex):
            condition.Param.from_value(value)

    @parameterized.named_parameters(
        (
            'bytes',
            condition.ArbitraryCondition(
                '$param = b"\x01\x02"',
                dict(param=condition.Param.from_value(b'\x01\x02')),
                segment=condition.Segment.WHERE,
            ),
        ),
        (
            'array_of_bytes',
            condition.ArbitraryCondition(
                '${param}[OFFSET(0)] = b"\x01\x02"',
                dict(param=condition.Param.from_value([b'\x01\x02'])),
                segment=condition.Segment.WHERE,
            ),
        ),
        (
            'array_of_bytes_and_null',
            condition.ArbitraryCondition(
                '${param}[OFFSET(0)] IS NULL',
                dict(param=condition.Param.from_value((None, b'\x01\x02'))),
                segment=condition.Segment.WHERE,
            ),
        ),
    )
    def test_param_from_value_correctly_encodes(self, tautology):
        """Each condition should hold for the saved row, so it must be returned."""
        test_model = models.SmallTestModel(
            dict(
                key='some-key',
                value_1='some-value',
                value_2='other-value',
            ))
        test_model.save()
        self.assertCountEqual((test_model,), models.SmallTestModel.where(tautology))

    @parameterized.named_parameters(
        (
            'minimal',
            condition.ArbitraryCondition(
                'FALSE',
                segment=condition.Segment.WHERE,
            ),
            {},
            {},
            'FALSE',
            (),
        ),
        (
            'full',
            condition.ArbitraryCondition(
                '$key = IF($true_param, ${key_param}, $value_1)',
                dict(
                    key=models.SmallTestModel.key,
                    true_param=condition.Param.from_value(True),
                    key_param=condition.Param.from_value('some-key'),
                    value_1=condition.Column('value_1'),
                ),
                segment=condition.Segment.WHERE,
            ),
            dict(
                true_param0=True,
                key_param0='some-key',
            ),
            dict(
                true_param0=type_pb2.Type(code=type_pb2.BOOL),
                key_param0=type_pb2.Type(code=type_pb2.STRING),
            ),
            ('SmallTestModel.key = '
             'IF(@true_param0, @key_param0, SmallTestModel.value_1)'),
            ('some-key',),
        ),
    )
    def test_arbitrary_condition(
        self,
        condition_,
        expected_params,
        expected_types,
        expected_sql,
        expected_row_keys,
    ):
        """ArbitraryCondition renders params, types and SQL, and filters rows."""
        models.SmallTestModel(
            dict(
                key='some-key',
                value_1='some-value',
                value_2='other-value',
            )).save()
        rows = models.SmallTestModel.where(condition_)
        self.assertEqual(expected_params, condition_.params())
        self.assertEqual(expected_types, condition_.types())
        self.assertEqual(expected_sql, condition_.sql())
        self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows))

    @parameterized.named_parameters(
        ('key_not_found', '$not_found', KeyError, 'not_found'),
        ('invalid_template', '$', ValueError, 'Invalid placeholder'),
    )
    def test_arbitrary_condition_template_error(
        self,
        template,
        error_class,
        error_regex,
    ):
        """Bad templates fail when the ArbitraryCondition is constructed."""
        with self.assertRaisesRegex(error_class, error_regex):
            condition.ArbitraryCondition(template, segment=condition.Segment.WHERE)

    @parameterized.named_parameters(
        (
            'field_from_wrong_model',
            models.ChildTestModel.key,
            'does not belong to the Model',
        ),
        (
            'column_not_found',
            condition.Column('not_found'),
            'does not exist in the Model',
        ),
    )
    def test_arbitrary_condition_validation_error(
        self,
        substitution,
        error_regex,
    ):
        """Substitutions invalid for the queried model fail when querying."""
        condition_ = condition.ArbitraryCondition(
            '$substitution',
            dict(substitution=substitution),
            segment=condition.Segment.WHERE,
        )
        with self.assertRaisesRegex(error.ValidationError, error_regex):
            models.SmallTestModel.where(condition_)

    @parameterized.named_parameters(
        (
            'empty_or',
            condition.OrCondition(),
            {},
            {},
            'FALSE',
            '',
        ),
        (
            'empty_and',
            condition.OrCondition([]),
            {},
            {},
            '(TRUE)',
            'ab',
        ),
        (
            'single',
            condition.OrCondition(
                [condition.equal_to(models.SmallTestModel.key, 'a')]),
            dict(key0='a'),
            dict(key0=type_pb2.Type(code=type_pb2.STRING)),
            '((SmallTestModel.key = @key0))',
            'a',
        ),
        (
            'multiple',
            condition.OrCondition(
                [
                    condition.equal_to(models.SmallTestModel.key, 'a'),
                    condition.equal_to(models.SmallTestModel.value_1, 'a'),
                ],
                [
                    condition.equal_to(models.SmallTestModel.key, 'b'),
                    condition.equal_to(models.SmallTestModel.value_1, 'b'),
                ],
            ),
            dict(
                key0='a',
                value_11='a',
                key2='b',
                value_13='b',
            ),
            dict(
                key0=type_pb2.Type(code=type_pb2.STRING),
                value_11=type_pb2.Type(code=type_pb2.STRING),
                key2=type_pb2.Type(code=type_pb2.STRING),
                value_13=type_pb2.Type(code=type_pb2.STRING),
            ),
            ('('
             '(SmallTestModel.key = @key0 AND SmallTestModel.value_1 = @value_11)'
             ' OR '
             '(SmallTestModel.key = @key2 AND SmallTestModel.value_1 = @value_13)'
             ')'),
            'ab',
        ),
    )
    def test_or_condition(
        self,
        condition_,
        expected_params,
        expected_types,
        expected_sql,
        expected_row_keys,
    ):
        """OrCondition renders SQL/params/types and returns the matching rows."""
        models.SmallTestModel(dict(key='a', value_1='a', value_2='a')).save()
        models.SmallTestModel(dict(key='b', value_1='b', value_2='b')).save()
        rows = models.SmallTestModel.where(condition_)
        self.assertEqual(expected_params, condition_.params())
        self.assertEqual(expected_types, condition_.types())
        self.assertEqual(expected_sql, condition_.sql())
        self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows))

    @parameterized.parameters(
        ('ABCD', 'BC', True),
        ('ABCD', 'bc', False),
        ('ABCD', 'CB', False),
        (b'ABCD', b'BC', True),
        (b'ABCD', b'bc', False),
        (b'ABCD', b'CB', False),
        ('ABCD', 'BC', True, dict(case_sensitive=False)),
        ('ABCD', 'bc', True, dict(case_sensitive=False)),
        ('ABCD', 'CB', False, dict(case_sensitive=False)),
        (b'ABCD', b'BC', True, dict(case_sensitive=False)),
        (b'ABCD', b'bc', True, dict(case_sensitive=False)),
        (b'ABCD', b'CB', False, dict(case_sensitive=False)),
    )
    def test_contains(
        self,
        haystack,
        needle,
        expect_results,
        kwargs=None,
    ):
        """spanner_orm.contains matches substrings, optionally ignoring case.

        Args:
          haystack: String or bytes value to search in.
          needle: String or bytes value to search for.
          expect_results: Whether the saved row is expected to be returned.
          kwargs: Extra keyword arguments for spanner_orm.contains, or None for
            none.
        """
        # Use None instead of a shared mutable {} default argument.
        kwargs = {} if kwargs is None else kwargs
        test_model = models.SmallTestModel(dict(key='a', value_1='a', value_2='a'))
        test_model.save()
        self.assertCountEqual(
            ((test_model,) if expect_results else ()),
            models.SmallTestModel.where(
                spanner_orm.contains(
                    condition.Param.from_value(haystack),
                    condition.Param.from_value(needle),
                    **kwargs,
                )),
        )