Пример #1
0
  def validate(self) -> None:
    """Checks that the target column can be dropped from its table.

    Raises:
      error.SpannerError: If the metadata has no model for the table, if the
        column is not among that model's fields, or if at least one index
        column entry matches the column and table names.
    """
    table_model = metadata.SpannerMetadata.model(self._table)
    if not table_model:
      raise error.SpannerError('Table {} does not exist'.format(self._table))

    column_known = self._column in table_model.fields
    if not column_known:
      message = 'Column {} does not exist on {}'.format(
          self._column, self._table)
      raise error.SpannerError(message)

    # The column must not appear in any index, so count the index column
    # entries that reference it on this table.
    index_filters = (
        condition.equal_to('column_name', self._column),
        condition.equal_to('table_name', self._table),
    )
    indexed_count = index_column.IndexColumnSchema.count(None, *index_filters)
    if indexed_count > 0:
      raise error.SpannerError('Column {} is indexed'.format(self._column))
Пример #2
0
    def test_or(self):
        """Two single-condition branches combined with or_ render as an OR."""
        first_branch = [condition.equal_to("int_", 1)]
        second_branch = [condition.equal_to("int_", 2)]
        query = self.select(condition.or_(first_branch, second_branch))

        # Each branch is parenthesized and receives its own numbered parameter.
        self.assertEndsWith(
            query.sql(),
            "((table.int_ = @int_0) OR (table.int_ = @int_1))",
        )
        self.assertEqual(query.parameters(), {"int_0": 1, "int_1": 2})

        expected_types = {
            name: field.Integer.grpc_type() for name in ("int_0", "int_1")
        }
        self.assertEqual(query.types(), expected_types)
Пример #3
0
 def test_includes_subconditions_query(self):
     """The SQL joins on the relation key and ANDs the subcondition onto it."""
     subcondition = condition.equal_to('key', 'value')
     query = self.includes('parents', subcondition)
     # The expected text is handed to assertRegex, so it is matched as a
     # pattern anywhere inside the generated SQL.
     self.assertRegex(
         query.sql(),
         'WHERE SmallTestModel.key = RelationshipTestModel.parent_key '
         'AND SmallTestModel.key = @key0',
     )
Пример #4
0
    def count_equal(cls,
                    transaction: Optional[
                        spanner_transaction.Transaction] = None,
                    **constraints: Any) -> int:
        """Counts the objects in Spanner that match the keyword constraints.

        Shortcut around `count` that derives one condition per keyword
        argument.

        Args:
          transaction: The existing transaction to use, or None to start a new
            transaction
          **constraints: Each key names a column. A list value becomes an
            "in list" condition on that column; any other value becomes an
            equality condition against that value.

        Returns:
          The integer result of the COUNT query
        """
        derived_conditions = [
            condition.in_list(name, wanted)
            if isinstance(wanted, list) else condition.equal_to(name, wanted)
            for name, wanted in constraints.items()
        ]
        return cls.count(transaction, *derived_conditions)
Пример #5
0
 def test_includes_subconditions_query(self):
     """Checks the WHERE clause built for an include with a subcondition."""
     key_filter = condition.equal_to("key", "value")
     generated_sql = self.includes("parents", key_filter).sql()
     # assertRegex searches for the pattern, so only this fragment of the
     # statement is pinned.
     where_fragment = (
         "WHERE SmallTestModel.key = RelationshipTestModel.parent_key "
         "AND SmallTestModel.key = @key0")
     self.assertRegex(generated_sql, where_fragment)
Пример #6
0
    def where_equal(
        cls,
        transaction: Optional[spanner_transaction.Transaction] = None,
        **constraints: Any
    ) -> List["ModelObject"]:
        """Fetches the objects in Spanner that match the keyword constraints.

        Args:
          transaction: The existing transaction to use, or None to start a new
            transaction
          **constraints: Each key names a column. A list value is matched with
            an "in list" condition on that column; any other value is matched
            with an equality condition.

        Returns:
          A list containing all requested objects that exist in the table (can
          be an empty list)
        """

        def build(column_name, wanted):
            # Lists select the IN-list form; everything else compares directly.
            if isinstance(wanted, list):
                return condition.in_list(column_name, wanted)
            return condition.equal_to(column_name, wanted)

        filters = [build(name, wanted) for name, wanted in constraints.items()]
        return cls.where(transaction, *filters)
Пример #7
0
 def test_includes_error_on_invalid_subconditions(
         self, column, value, relation, foreign_key_relation):
     """An include with an invalid subcondition raises ValidationError."""
     with self.assertRaises(error.ValidationError):
         # The condition is built inside the context so an error raised at
         # construction time is also accepted.
         subcondition = condition.equal_to(column, value)
         self.includes(relation, subcondition, foreign_key_relation)
Пример #8
0
 def test_query_combines_properly(self):
     """Filters, ordering and limit are merged into one SQL tail."""
     conditions = (
         condition.equal_to('int_', 5),
         condition.not_equal_to('string_array', ['foo', 'bar']),
         condition.limit(2),
         condition.order_by(('string', condition.OrderType.DESC)),
     )
     query = self.select(*conditions)
     # LIMIT is passed before ORDER BY above but is expected after it in the
     # generated SQL.
     self.assertEndsWith(
         query.sql(),
         'WHERE table.int_ = @int_0 AND table.string_array != '
         '@string_array1 ORDER BY table.string DESC LIMIT @limit2',
     )
Пример #9
0
 def test_query_combines_properly(self):
     """Checks the SQL tail produced by mixing WHERE, ORDER BY and LIMIT."""
     int_filter = condition.equal_to("int_", 5)
     array_filter = condition.not_equal_to("string_array", ["foo", "bar"])
     row_limit = condition.limit(2)
     ordering = condition.order_by(("string", condition.OrderType.DESC))

     generated_sql = self.select(
         int_filter, array_filter, row_limit, ordering).sql()

     # Parameters are numbered in the order the conditions were supplied.
     sql_tail = (
         "WHERE table.int_ = @int_0 AND table.string_array != "
         "@string_array1 ORDER BY table.string DESC LIMIT @limit2")
     self.assertEndsWith(generated_sql, sql_tail)
Пример #10
0
    def indexes(cls) -> Dict[str, Dict[str, Any]]:
        """Builds index objects from the index and index column schemas.

        Returns:
          A mapping of table name to a mapping of index name to the index.Index
          built for that index.
        """
        keyed_columns = collections.defaultdict(list)
        stored_columns = collections.defaultdict(list)

        # Rows arrive sorted by ordinal_position (the column's position within
        # its index), so appending keeps each index's columns in order.
        column_rows = index_column.IndexColumnSchema.where(
            None,
            condition.equal_to("table_catalog", ""),
            condition.equal_to("table_schema", ""),
            condition.order_by(("ordinal_position", condition.OrderType.ASC)),
        )
        for row in column_rows:
            # Rows without an ordinal position are recorded as storing columns.
            if row.ordinal_position is None:
                bucket = stored_columns
            else:
                bucket = keyed_columns
            bucket[(row.table_name, row.index_name)].append(row.column_name)

        result = collections.defaultdict(dict)
        index_rows = index_schema.IndexSchema.where(
            None,
            condition.equal_to("table_catalog", ""),
            condition.equal_to("table_schema", ""),
        )
        for row in index_rows:
            lookup = (row.table_name, row.index_name)
            built = index.Index(
                keyed_columns[lookup],
                parent=row.parent_table_name,
                null_filtered=row.is_null_filtered,
                unique=row.is_unique,
                storing_columns=stored_columns[lookup],
            )
            built.name = row.index_name
            result[row.table_name][row.index_name] = built
        return result
Пример #11
0
    def test_includes_subcondition_result(self):
        """Processed rows yield a child object carrying both related parents."""
        subcondition = condition.equal_to('key', 'value')
        query = self.includes('parents', subcondition)

        child_values, parent_values, rows = self.includes_result(related=2)
        child = query.process_results(rows)[0]

        self.assertLen(child.parents, 2)
        for attribute, expected in child_values.items():
            self.assertEqual(getattr(child, attribute), expected)

        # Only the first related parent is compared field by field.
        first_parent = child.parents[0]
        for attribute, expected in parent_values.items():
            self.assertEqual(getattr(first_parent, attribute), expected)
Пример #12
0
  def tables(cls) -> Dict[str, Dict[str, Any]]:
    """Builds per-table information from the column and table schemas.

    Returns:
      A mapping of table name to a dict with two entries: 'parent_table' (the
      table row's parent_table_name) and 'fields' (a mapping of column name to
      the field.Field built from the matching column schema row).
    """
    fields_by_table = collections.defaultdict(dict)
    column_rows = column.ColumnSchema.where(
        None,
        condition.equal_to('table_catalog', ''),
        condition.equal_to('table_schema', ''),
    )
    for row in column_rows:
      built_field = field.Field(row.field_type(), nullable=row.nullable())
      built_field.name = row.column_name
      built_field.position = row.ordinal_position
      fields_by_table[row.table_name][row.column_name] = built_field

    result = collections.defaultdict(dict)
    table_rows = table.TableSchema.where(
        None,
        condition.equal_to('table_catalog', ''),
        condition.equal_to('table_schema', ''),
    )
    for row in table_rows:
      table_name = row.table_name
      entry = result[table_name]
      entry['parent_table'] = row.parent_table_name
      # A table with no column rows gets an empty mapping from the defaultdict.
      entry['fields'] = fields_by_table[table_name]
    return result
Пример #13
0
 def test_includes_error_on_invalid_subconditions(self, column, value):
     """Including 'parent' with an invalid subcondition fails validation."""
     expect_validation_error = self.assertRaises(error.ValidationError)
     with expect_validation_error:
         # Built inside the context so a construction-time error also counts.
         invalid_subcondition = condition.equal_to(column, value)
         self.includes('parent', invalid_subcondition)
Пример #14
0
class ConditionTest(
    spanner_emulator_testlib.TestCase,
    parameterized.TestCase,
):

  def setUp(self):
    """Runs the base setUp, then applies the ORM migrations for this test.

    The migrations are read from the 'migrations_for_emulator_test' directory
    located next to this file.
    """
    super().setUp()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    migrations_dir = os.path.join(this_dir, 'migrations_for_emulator_test')
    self.run_orm_migrations(migrations_dir)

  @parameterized.parameters(
      (True, type_pb2.Type(code=type_pb2.BOOL)),
      (0, type_pb2.Type(code=type_pb2.INT64)),
      (0.0, type_pb2.Type(code=type_pb2.FLOAT64)),
      (
          datetime_helpers.DatetimeWithNanoseconds(2021, 1, 5),
          type_pb2.Type(code=type_pb2.TIMESTAMP),
      ),
      (datetime.datetime(2021, 1, 5), type_pb2.Type(code=type_pb2.TIMESTAMP)),
      (datetime.date(2021, 1, 5), type_pb2.Type(code=type_pb2.DATE)),
      (b'\x01', type_pb2.Type(code=type_pb2.BYTES)),
      ('foo', type_pb2.Type(code=type_pb2.STRING)),
      (decimal.Decimal('1.23'), type_pb2.Type(code=type_pb2.NUMERIC)),
      (
          (0, 1),
          type_pb2.Type(
              code=type_pb2.ARRAY,
              array_element_type=type_pb2.Type(code=type_pb2.INT64),
          ),
      ),
      (
          ['a', None, 'b'],
          type_pb2.Type(
              code=type_pb2.ARRAY,
              array_element_type=type_pb2.Type(code=type_pb2.STRING),
          ),
      ),
  )
  def test_param_from_value(self, value, expected_type):
    """Param.from_value infers the expected type and the pair is queryable."""
    inferred = condition.Param.from_value(value)
    self.assertEqual(expected_type, inferred.type)

    # Use the parameter in a real query: an incompatible value/type pair makes
    # the query raise. No row is expected to match.
    null_check = condition.ArbitraryCondition(
        '$param IS NULL',
        {'param': inferred},
        segment=condition.Segment.WHERE,
    )
    matches = models.SmallTestModel.where(null_check)
    self.assertEmpty(matches)

  @parameterized.parameters(
      (None, 'Cannot infer type of None'),
      ((0, 'some-string'), 'elements of exactly one type'),
      ((0, 'some-string', None), 'elements of exactly one type'),
      (object(), 'Unknown type'),
  )
  def test_param_from_value_error(self, value, error_regex):
    """Param.from_value raises a TypeError matching error_regex."""
    expect_type_error = self.assertRaisesRegex(TypeError, error_regex)
    with expect_type_error:
      condition.Param.from_value(value)

  @parameterized.named_parameters(
      (
          'bytes',
          condition.ArbitraryCondition(
              '$param = b"\x01\x02"',
              dict(param=condition.Param.from_value(b'\x01\x02')),
              segment=condition.Segment.WHERE,
          ),
      ),
      (
          'array_of_bytes',
          condition.ArbitraryCondition(
              '${param}[OFFSET(0)] = b"\x01\x02"',
              dict(param=condition.Param.from_value([b'\x01\x02'])),
              segment=condition.Segment.WHERE,
          ),
      ),
      (
          'array_of_bytes_and_null',
          condition.ArbitraryCondition(
              '${param}[OFFSET(0)] IS NULL',
              dict(param=condition.Param.from_value((None, b'\x01\x02'))),
              segment=condition.Segment.WHERE,
          ),
      ),
  )
  def test_param_from_value_correctly_encodes(self, tautology):
    """A condition expected to hold for every row returns the saved row."""
    saved_row = models.SmallTestModel({
        'key': 'some-key',
        'value_1': 'some-value',
        'value_2': 'other-value',
    })
    saved_row.save()

    # The saved row comes back only if the condition evaluates to true.
    matches = models.SmallTestModel.where(tautology)
    self.assertCountEqual((saved_row,), matches)

  @parameterized.named_parameters(
      (
          'minimal',
          condition.ArbitraryCondition(
              'FALSE',
              segment=condition.Segment.WHERE,
          ),
          {},
          {},
          'FALSE',
          (),
      ),
      (
          'full',
          condition.ArbitraryCondition(
              '$key = IF($true_param, ${key_param}, $value_1)',
              dict(
                  key=models.SmallTestModel.key,
                  true_param=condition.Param.from_value(True),
                  key_param=condition.Param.from_value('some-key'),
                  value_1=condition.Column('value_1'),
              ),
              segment=condition.Segment.WHERE,
          ),
          dict(
              true_param0=True,
              key_param0='some-key',
          ),
          dict(
              true_param0=type_pb2.Type(code=type_pb2.BOOL),
              key_param0=type_pb2.Type(code=type_pb2.STRING),
          ),
          ('SmallTestModel.key = '
           'IF(@true_param0, @key_param0, SmallTestModel.value_1)'),
          ('some-key',),
      ),
  )
  def test_arbitrary_condition(
      self,
      condition_,
      expected_params,
      expected_types,
      expected_sql,
      expected_row_keys,
  ):
    """Checks params, types, SQL and query results of an ArbitraryCondition.

    Args:
      condition_: The condition under test.
      expected_params: Expected return value of condition_.params().
      expected_types: Expected return value of condition_.types().
      expected_sql: Expected return value of condition_.sql().
      expected_row_keys: Keys of the rows the query is expected to return.
    """
    seed_row = models.SmallTestModel({
        'key': 'some-key',
        'value_1': 'some-value',
        'value_2': 'other-value',
    })
    seed_row.save()

    # The query is issued before the condition's accessors are inspected.
    matched_rows = models.SmallTestModel.where(condition_)

    self.assertEqual(expected_params, condition_.params())
    self.assertEqual(expected_types, condition_.types())
    self.assertEqual(expected_sql, condition_.sql())
    matched_keys = tuple(row.key for row in matched_rows)
    self.assertCountEqual(expected_row_keys, matched_keys)

  @parameterized.named_parameters(
      ('key_not_found', '$not_found', KeyError, 'not_found'),
      ('invalid_template', '$', ValueError, 'Invalid placeholder'),
  )
  def test_arbitrary_condition_template_error(
      self,
      template,
      error_class,
      error_regex,
  ):
    """Constructing a condition from a bad template raises the given error.

    Args:
      template: Template string passed to ArbitraryCondition with no
        substitutions.
      error_class: Exception class expected from the constructor.
      error_regex: Pattern the exception message must match.
    """
    expect_error = self.assertRaisesRegex(error_class, error_regex)
    with expect_error:
      condition.ArbitraryCondition(
          template,
          segment=condition.Segment.WHERE,
      )

  @parameterized.named_parameters(
      (
          'field_from_wrong_model',
          models.ChildTestModel.key,
          'does not belong to the Model',
      ),
      (
          'column_not_found',
          condition.Column('not_found'),
          'does not exist in the Model',
      ),
  )
  def test_arbitrary_condition_validation_error(
      self,
      substitution,
      error_regex,
  ):
    """An invalid substitution is rejected when the condition is queried.

    Args:
      substitution: Value substituted for the template's single placeholder.
      error_regex: Pattern the ValidationError message must match.
    """
    # Construction happens outside the assertion context: only the query is
    # expected to raise.
    invalid_condition = condition.ArbitraryCondition(
        '$substitution',
        {'substitution': substitution},
        segment=condition.Segment.WHERE,
    )
    expect_error = self.assertRaisesRegex(error.ValidationError, error_regex)
    with expect_error:
      models.SmallTestModel.where(invalid_condition)

  @parameterized.named_parameters(
      (
          'empty_or',
          condition.OrCondition(),
          {},
          {},
          'FALSE',
          '',
      ),
      (
          'empty_and',
          condition.OrCondition([]),
          {},
          {},
          '(TRUE)',
          'ab',
      ),
      (
          'single',
          condition.OrCondition(
              [condition.equal_to(models.SmallTestModel.key, 'a')]),
          dict(key0='a'),
          dict(key0=type_pb2.Type(code=type_pb2.STRING)),
          '((SmallTestModel.key = @key0))',
          'a',
      ),
      (
          'multiple',
          condition.OrCondition(
              [
                  condition.equal_to(models.SmallTestModel.key, 'a'),
                  condition.equal_to(models.SmallTestModel.value_1, 'a'),
              ],
              [
                  condition.equal_to(models.SmallTestModel.key, 'b'),
                  condition.equal_to(models.SmallTestModel.value_1, 'b'),
              ],
          ),
          dict(
              key0='a',
              value_11='a',
              key2='b',
              value_13='b',
          ),
          dict(
              key0=type_pb2.Type(code=type_pb2.STRING),
              value_11=type_pb2.Type(code=type_pb2.STRING),
              key2=type_pb2.Type(code=type_pb2.STRING),
              value_13=type_pb2.Type(code=type_pb2.STRING),
          ),
          ('('
           '(SmallTestModel.key = @key0 AND SmallTestModel.value_1 = @value_11)'
           ' OR '
           '(SmallTestModel.key = @key2 AND SmallTestModel.value_1 = @value_13)'
           ')'),
          'ab',
      ),
  )
  def test_or_condition(
      self,
      condition_,
      expected_params,
      expected_types,
      expected_sql,
      expected_row_keys,
  ):
    """Checks params, types, SQL and query results of an OrCondition.

    Args:
      condition_: The condition under test.
      expected_params: Expected return value of condition_.params().
      expected_types: Expected return value of condition_.types().
      expected_sql: Expected return value of condition_.sql().
      expected_row_keys: Iterable of the keys the query should return.
    """
    # Two rows are stored; each uses one letter for its key and both values.
    for letter in ('a', 'b'):
      models.SmallTestModel({
          'key': letter,
          'value_1': letter,
          'value_2': letter,
      }).save()

    matched_rows = models.SmallTestModel.where(condition_)

    self.assertEqual(expected_params, condition_.params())
    self.assertEqual(expected_types, condition_.types())
    self.assertEqual(expected_sql, condition_.sql())
    matched_keys = tuple(row.key for row in matched_rows)
    self.assertCountEqual(expected_row_keys, matched_keys)

  @parameterized.parameters(
      ('ABCD', 'BC', True),
      ('ABCD', 'bc', False),
      ('ABCD', 'CB', False),
      (b'ABCD', b'BC', True),
      (b'ABCD', b'bc', False),
      (b'ABCD', b'CB', False),
      ('ABCD', 'BC', True, dict(case_sensitive=False)),
      ('ABCD', 'bc', True, dict(case_sensitive=False)),
      ('ABCD', 'CB', False, dict(case_sensitive=False)),
      (b'ABCD', b'BC', True, dict(case_sensitive=False)),
      (b'ABCD', b'bc', True, dict(case_sensitive=False)),
      (b'ABCD', b'CB', False, dict(case_sensitive=False)),
  )
  def test_contains(
      self,
      haystack,
      needle,
      expect_results,
      kwargs=None,
  ):
    """Checks spanner_orm.contains() against a single saved row.

    Args:
      haystack: Value wrapped in a Param and passed as the first argument to
        contains().
      needle: Value wrapped in a Param and passed as the second argument to
        contains().
      expect_results: Whether the query is expected to return the saved row.
      kwargs: Optional dict of extra keyword arguments forwarded to
        contains(); None forwards nothing.
    """
    # A None sentinel replaces the former `{}` default so that no mutable dict
    # is shared between calls.
    extra_kwargs = {} if kwargs is None else kwargs

    test_model = models.SmallTestModel(dict(key='a', value_1='a', value_2='a'))
    test_model.save()

    # Both operands are parameters, so the condition does not depend on the
    # row's contents: the query returns either the saved row or nothing.
    expected_rows = (test_model,) if expect_results else ()
    contains_condition = spanner_orm.contains(
        condition.Param.from_value(haystack),
        condition.Param.from_value(needle),
        **extra_kwargs,
    )
    self.assertCountEqual(
        expected_rows,
        models.SmallTestModel.where(contains_condition),
    )