Ejemplo n.º 1
0
    def send_body(self, f, protocol_version):
        """
        Serialize the PREPARE request body onto ``f``.

        Writes the query text, then (on protocol versions that support
        prepare flags) a flags word, then the keyspace name when the
        keyspace flag bit was set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: if a keyspace is set on a protocol
            version without keyspace-flag support, or if flags would be
            required on a version that cannot carry them
        """
        write_longstring(f, self.query)

        flags = 0x00

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _PREPARED_WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5."
                )

        if ProtocolVersion.uses_prepare_flags(protocol_version):
            write_uint(f, flags)
        elif flags:
            # checks above should prevent this, but just to be safe...
            raise UnsupportedOperation(
                "Attempted to set flags with value {flags:0=#8x} on "
                "protocol version {pv}, which doesn't support flags "
                "in prepared statements. "
                "Consider setting Cluster.protocol_version to 5."
                "".format(flags=flags, pv=protocol_version))

        # Write the keyspace exactly when its flag bit is set, so the flag and
        # the trailing string can never disagree (an empty-string keyspace
        # previously set the flag but skipped the string, corrupting the frame).
        if flags & _PREPARED_WITH_KEYSPACE_FLAG:
            write_string(f, self.keyspace)
Ejemplo n.º 2
0
    def send_body(self, f, protocol_version):
        """
        Serialize the PREPARE request body onto ``f``.

        Writes the query text, then (on protocol versions that support
        prepare flags) a flags word, then the keyspace name when the
        keyspace flag bit was set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: if a keyspace is set on a protocol
            version without keyspace-flag support, or if flags would be
            required on a version that cannot carry them
        """
        write_longstring(f, self.query)

        flags = 0x00

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _PREPARED_WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or higher. Consider setting Cluster.protocol_version to 5.")

        if ProtocolVersion.uses_prepare_flags(protocol_version):
            write_uint(f, flags)
        elif flags:
            # checks above should prevent this, but just to be safe...
            raise UnsupportedOperation(
                "Attempted to set flags with value {flags:0=#8x} on "
                "protocol version {pv}, which doesn't support flags "
                "in prepared statements. "
                "Consider setting Cluster.protocol_version to 5."
                "".format(flags=flags, pv=protocol_version))

        # Write the keyspace exactly when its flag bit is set, so the flag and
        # the trailing string can never disagree (an empty-string keyspace
        # previously set the flag but skipped the string, corrupting the frame).
        if flags & _PREPARED_WITH_KEYSPACE_FLAG:
            write_string(f, self.keyspace)
Ejemplo n.º 3
0
    def send_body(self, f, protocol_version):
        """
        Serialize the EXECUTE request body onto ``f``.

        Protocol v1 uses a fixed layout (bound values, then consistency) and
        cannot express serial consistency or paging. Later versions write the
        consistency, a flags word announcing optional trailing fields, the
        bound values, and then each optional field whose flag was set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: when the message uses a feature the
            given protocol version cannot carry
        """
        write_string(f, self.query_id)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            write_string(f, self.result_metadata_id)

        bound_values = self.query_params
        serial_cl = self.serial_consistency_level
        page_size = self.fetch_size
        paging_state = self.paging_state

        if protocol_version == 1:
            # v1 has no flags word, so none of the optional fields can be sent.
            if serial_cl:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")
            if page_size or paging_state:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2."
                )
            write_short(f, len(bound_values))
            for value in bound_values:
                write_value(f, value)
            write_consistency_level(f, self.consistency_level)
            return

        write_consistency_level(f, self.consistency_level)

        timestamp = self.timestamp
        if timestamp is not None and protocol_version < 3:
            raise UnsupportedOperation(
                "Protocol-level timestamps may only be used with protocol version "
                "3 or higher. Consider setting Cluster.protocol_version to 3."
            )

        flags = _VALUES_FLAG
        for present, bit in ((serial_cl, _WITH_SERIAL_CONSISTENCY_FLAG),
                             (page_size, _PAGE_SIZE_FLAG),
                             (paging_state, _WITH_PAGING_STATE_FLAG),
                             (timestamp is not None, _PROTOCOL_TIMESTAMP),
                             (self.skip_meta, _SKIP_METADATA_FLAG)):
            if present:
                flags |= bit

        write_flags = (write_uint
                       if ProtocolVersion.uses_int_query_flags(protocol_version)
                       else write_byte)
        write_flags(f, flags)

        write_short(f, len(bound_values))
        for value in bound_values:
            write_value(f, value)

        # Optional trailing fields, in the same order as their flag bits above.
        if page_size:
            write_int(f, page_size)
        if paging_state:
            write_longstring(f, paging_state)
        if serial_cl:
            write_consistency_level(f, serial_cl)
        if timestamp is not None:
            write_long(f, timestamp)
Ejemplo n.º 4
0
    def send_body(self, f, protocol_version):
        """
        Serialize the EXECUTE request body onto ``f``.

        Protocol v1 uses a fixed layout (bound values, then consistency) and
        cannot express serial consistency or paging. Later versions write the
        consistency, a flags word announcing optional trailing fields, the
        bound values, and then each optional field whose flag was set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: when the message uses a feature the
            given protocol version cannot carry
        """
        write_string(f, self.query_id)
        if ProtocolVersion.uses_prepared_metadata(protocol_version):
            write_string(f, self.result_metadata_id)

        bound_values = self.query_params
        serial_cl = self.serial_consistency_level
        page_size = self.fetch_size
        paging_state = self.paging_state

        if protocol_version == 1:
            # v1 has no flags word, so none of the optional fields can be sent.
            if serial_cl:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")
            if page_size or paging_state:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")
            write_short(f, len(bound_values))
            for value in bound_values:
                write_value(f, value)
            write_consistency_level(f, self.consistency_level)
            return

        write_consistency_level(f, self.consistency_level)

        timestamp = self.timestamp
        if timestamp is not None and protocol_version < 3:
            raise UnsupportedOperation(
                "Protocol-level timestamps may only be used with protocol version "
                "3 or higher. Consider setting Cluster.protocol_version to 3.")

        flags = _VALUES_FLAG
        for present, bit in ((serial_cl, _WITH_SERIAL_CONSISTENCY_FLAG),
                             (page_size, _PAGE_SIZE_FLAG),
                             (paging_state, _WITH_PAGING_STATE_FLAG),
                             (timestamp is not None, _PROTOCOL_TIMESTAMP),
                             (self.skip_meta, _SKIP_METADATA_FLAG)):
            if present:
                flags |= bit

        write_flags = (write_uint
                       if ProtocolVersion.uses_int_query_flags(protocol_version)
                       else write_byte)
        write_flags(f, flags)

        write_short(f, len(bound_values))
        for value in bound_values:
            write_value(f, value)

        # Optional trailing fields, in the same order as their flag bits above.
        if page_size:
            write_int(f, page_size)
        if paging_state:
            write_longstring(f, paging_state)
        if serial_cl:
            write_consistency_level(f, serial_cl)
        if timestamp is not None:
            write_long(f, timestamp)
Ejemplo n.º 5
0
    def decode_message(cls, protocol_version, user_type_map, stream_id, flags,
                       opcode, body, decompressor, result_metadata):
        """
        Decode a native protocol message body into a message object.

        :param protocol_version: version to use decoding contents
        :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
        :param stream_id: native protocol stream id from the frame header
        :param flags: native protocol flags bitmap from the header
        :param opcode: native protocol opcode from the header
        :param body: frame body
        :param decompressor: optional decompression function to inflate the body
        :param result_metadata: forwarded unchanged to the message class's ``recv_body``
        :return: a message decoded from the body and frame attributes
        :raises RuntimeError: if the frame is marked compressed but no
            decompressor was supplied
        """
        # Versions with checksumming handle compression at the segment layer,
        # so only frames from older versions are inflated here.
        frame_level_compression = not ProtocolVersion.has_checksumming_support(
            protocol_version)
        if frame_level_compression and flags & COMPRESSED_FLAG:
            if decompressor is None:
                raise RuntimeError(
                    "No de-compressor available for compressed frame!")
            body = decompressor(body)
            flags ^= COMPRESSED_FLAG

        stream = io.BytesIO(body)

        # Optional prefixes are read in this fixed order; each consumed flag is
        # cleared so any leftover bits can be reported below.
        trace_id = None
        if flags & TRACING_FLAG:
            trace_id = UUID(bytes=stream.read(16))
            flags ^= TRACING_FLAG

        server_warnings = None
        if flags & WARNING_FLAG:
            server_warnings = read_stringlist(stream)
            flags ^= WARNING_FLAG

        custom_payload = None
        if flags & CUSTOM_PAYLOAD_FLAG:
            custom_payload = read_bytesmap(stream)
            flags ^= CUSTOM_PAYLOAD_FLAG

        # will only be set if we asserted it in connection establishment
        flags &= USE_BETA_MASK
        if flags:
            log.warning(
                "Unknown protocol flags set: %02x. May cause problems.", flags)

        message_class = cls.message_types_by_opcode[opcode]
        msg = message_class.recv_body(stream, protocol_version, user_type_map,
                                      result_metadata)
        msg.stream_id = stream_id
        msg.trace_id = trace_id
        msg.custom_payload = custom_payload
        msg.warnings = server_warnings

        for warning_text in msg.warnings or ():
            log.warning("Server warning: %s", warning_text)

        return msg
Ejemplo n.º 6
0
    def send_body(self, f, protocol_version):
        """
        Serialize the BATCH request body onto ``f``.

        Each entry of ``self.queries`` is ``(prepared, string_or_query_id,
        params)``: plain statements are written as a long string, prepared
        ones as a short-length-prefixed id. The consistency level follows;
        from protocol version 3 a flags word and the optional serial
        consistency / timestamp fields are appended.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for is_prepared, query_or_id, bound_values in self.queries:
            if is_prepared:
                write_byte(f, 1)
                write_short(f, len(query_or_id))
                f.write(query_or_id)
            else:
                write_byte(f, 0)
                write_longstring(f, query_or_id)
            write_short(f, len(bound_values))
            for value in bound_values:
                write_value(f, value)

        write_consistency_level(f, self.consistency_level)
        if protocol_version < 3:
            # Before v3 the batch body ends after the consistency level.
            return

        serial_cl = self.serial_consistency_level
        timestamp = self.timestamp

        flags = 0
        if serial_cl:
            flags |= _WITH_SERIAL_CONSISTENCY_FLAG
        if timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP

        write_flags = (write_int
                       if ProtocolVersion.uses_int_query_flags(protocol_version)
                       else write_byte)
        write_flags(f, flags)

        if serial_cl:
            write_consistency_level(f, serial_cl)
        if timestamp is not None:
            write_long(f, timestamp)
Ejemplo n.º 7
0
    def _perform_cql_statement(self,
                               text,
                               consistency_level,
                               expected_exception,
                               session=None):
        """
        Execute a CQL statement and, optionally, check that it raises.

        @param text CQL statement to execute
        @param consistency_level Consistency level at which it is to be executed
        @param expected_exception Exception type expected to be raised, or None
        @param session Session to execute on; defaults to self.session
        """
        session = self.session if session is None else session
        statement = SimpleStatement(text)
        statement.consistency_level = consistency_level

        if expected_exception is None:
            self.execute_helper(session, statement)
            return

        with self.assertRaises(expected_exception) as context:
            self.execute_helper(session, statement)

        # Error code maps are only present on protocol versions that use them.
        if not ProtocolVersion.uses_error_code_map(PROTOCOL_VERSION):
            return

        raised = context.exception
        if isinstance(raised, ReadFailure):
            self.assertEqual(list(raised.error_code_map.values())[0], 1)
        if isinstance(raised, WriteFailure):
            self.assertEqual(list(raised.error_code_map.values())[0], 0)
Ejemplo n.º 8
0
    def send_body(self, f, protocol_version):
        """
        Serialize the BATCH request body onto ``f``.

        Each entry of ``self.queries`` is ``(prepared, string_or_query_id,
        params)``: plain statements are written as a long string, prepared
        ones as a short-length-prefixed id. The consistency level follows;
        from protocol version 3 a flags word and the optional serial
        consistency / timestamp fields are appended.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for is_prepared, query_or_id, bound_values in self.queries:
            if is_prepared:
                write_byte(f, 1)
                write_short(f, len(query_or_id))
                f.write(query_or_id)
            else:
                write_byte(f, 0)
                write_longstring(f, query_or_id)
            write_short(f, len(bound_values))
            for value in bound_values:
                write_value(f, value)

        write_consistency_level(f, self.consistency_level)
        if protocol_version < 3:
            # Before v3 the batch body ends after the consistency level.
            return

        serial_cl = self.serial_consistency_level
        timestamp = self.timestamp

        flags = 0
        if serial_cl:
            flags |= _WITH_SERIAL_CONSISTENCY_FLAG
        if timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP

        write_flags = (write_int
                       if ProtocolVersion.uses_int_query_flags(protocol_version)
                       else write_byte)
        write_flags(f, flags)

        if serial_cl:
            write_consistency_level(f, serial_cl)
        if timestamp is not None:
            write_long(f, timestamp)
Ejemplo n.º 9
0
 def recv_results_prepared(self, f, protocol_version, user_type_map):
     """
     Read a PREPARED result: the query id, the result metadata id (only on
     protocol versions that carry it, otherwise None), then the prepared
     metadata via ``recv_prepared_metadata``.
     """
     self.query_id = read_binary_string(f)
     has_metadata_id = ProtocolVersion.uses_prepared_metadata(protocol_version)
     self.result_metadata_id = read_binary_string(f) if has_metadata_id else None
     self.recv_prepared_metadata(f, protocol_version, user_type_map)
Ejemplo n.º 10
0
 def recv_results_prepared(cls, f, protocol_version, user_type_map):
     """
     Read a PREPARED result from ``f``.

     :return: ``(query_id, bind_metadata, pk_indexes, result_metadata,
         result_metadata_id)``; ``result_metadata_id`` is None on protocol
         versions that do not send it
     """
     query_id = read_binary_string(f)

     result_metadata_id = None
     if ProtocolVersion.uses_prepared_metadata(protocol_version):
         result_metadata_id = read_binary_string(f)

     prepared = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
     bind_metadata, pk_indexes, result_metadata, _unused = prepared
     return (query_id, bind_metadata, pk_indexes, result_metadata,
             result_metadata_id)
Ejemplo n.º 11
0
 def recv_results_prepared(cls, f, protocol_version, user_type_map):
     """
     Read a PREPARED result from ``f``.

     :return: ``(query_id, bind_metadata, pk_indexes, result_metadata,
         result_metadata_id)``; ``result_metadata_id`` is None on protocol
         versions that do not send it
     """
     query_id = read_binary_string(f)

     result_metadata_id = None
     if ProtocolVersion.uses_prepared_metadata(protocol_version):
         result_metadata_id = read_binary_string(f)

     prepared = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
     bind_metadata, pk_indexes, result_metadata, _unused = prepared
     return (query_id, bind_metadata, pk_indexes, result_metadata,
             result_metadata_id)
Ejemplo n.º 12
0
    def send_body(self, f, protocol_version):
        """
        Serialize the BATCH request body onto ``f``.

        Each entry of ``self.queries`` is ``(prepared, string_or_query_id,
        params)``. After the consistency level, protocol version 3+ appends
        a flags word followed by the optional serial consistency, timestamp
        and keyspace fields whose flag bits are set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: if a keyspace is set on a protocol
            version without keyspace-flag support
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for prepared, string_or_query_id, params in self.queries:
            if not prepared:
                write_byte(f, 0)
                write_longstring(f, string_or_query_id)
            else:
                write_byte(f, 1)
                write_short(f, len(string_or_query_id))
                f.write(string_or_query_id)
            write_short(f, len(params))
            for param in params:
                write_value(f, param)

        write_consistency_level(f, self.consistency_level)
        if protocol_version >= 3:
            flags = 0
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.timestamp is not None:
                flags |= _PROTOCOL_TIMESTAMP
            if self.keyspace:
                if ProtocolVersion.uses_keyspace_flag(protocol_version):
                    flags |= _WITH_KEYSPACE_FLAG
                else:
                    raise UnsupportedOperation(
                        "Keyspaces may only be set on queries with protocol version "
                        "5 or higher. Consider setting Cluster.protocol_version to 5."
                    )

            if ProtocolVersion.uses_int_query_flags(protocol_version):
                write_int(f, flags)
            else:
                write_byte(f, flags)

            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)

            # Write the keyspace exactly when its flag bit is set. Previously an
            # empty-string keyspace skipped the flag but was still written,
            # leaving an unannounced field in the frame.
            if flags & _WITH_KEYSPACE_FLAG:
                write_string(f, self.keyspace)
Ejemplo n.º 13
0
    def send_body(self, f, protocol_version):
        """
        Serialize the BATCH request body onto ``f``.

        Each entry of ``self.queries`` is ``(prepared, string_or_query_id,
        params)``. After the consistency level, protocol version 3+ appends
        a flags word followed by the optional serial consistency, timestamp
        and keyspace fields whose flag bits are set.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: if a keyspace is set on a protocol
            version without keyspace-flag support
        """
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for prepared, string_or_query_id, params in self.queries:
            if not prepared:
                write_byte(f, 0)
                write_longstring(f, string_or_query_id)
            else:
                write_byte(f, 1)
                write_short(f, len(string_or_query_id))
                f.write(string_or_query_id)
            write_short(f, len(params))
            for param in params:
                write_value(f, param)

        write_consistency_level(f, self.consistency_level)
        if protocol_version >= 3:
            flags = 0
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.timestamp is not None:
                flags |= _PROTOCOL_TIMESTAMP
            if self.keyspace:
                if ProtocolVersion.uses_keyspace_flag(protocol_version):
                    flags |= _WITH_KEYSPACE_FLAG
                else:
                    raise UnsupportedOperation(
                        "Keyspaces may only be set on queries with protocol version "
                        "5 or higher. Consider setting Cluster.protocol_version to 5.")

            if ProtocolVersion.uses_int_query_flags(protocol_version):
                write_int(f, flags)
            else:
                write_byte(f, flags)

            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)

            # Write the keyspace exactly when its flag bit is set. Previously an
            # empty-string keyspace skipped the flag but was still written,
            # leaving an unannounced field in the frame.
            if flags & _WITH_KEYSPACE_FLAG:
                write_string(f, self.keyspace)
Ejemplo n.º 14
0
    def send_body(self, f, protocol_version):
        """
        Serialize the QUERY request body onto ``f``.

        Writes the query text and consistency level, then a flags word
        announcing which optional fields follow (bound values, page size,
        paging state, serial consistency, timestamp), then those fields.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: when serial consistency or paging is used
            below protocol v2, or a protocol-level timestamp below v3
        """
        write_longstring(f, self.query)
        write_consistency_level(f, self.consistency_level)
        flags = 0x00
        if self._query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2."
                )

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2."
                )

        if self.timestamp is not None:
            # The timestamp flag is not defined before v3; refuse instead of
            # emitting an undefined flag bit (matches ExecuteMessage).
            if protocol_version >= 3:
                flags |= _PROTOCOL_TIMESTAMP
            else:
                raise UnsupportedOperation(
                    "Protocol-level timestamps may only be used with protocol version "
                    "3 or higher. Consider setting Cluster.protocol_version to 3."
                )

        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        if self._query_params is not None:
            write_short(f, len(self._query_params))
            for param in self._query_params:
                write_value(f, param)

        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
Ejemplo n.º 15
0
    def send_body(self, f, protocol_version):
        """
        Serialize the QUERY request body onto ``f``.

        Writes the query text and consistency level, then a flags word
        announcing which optional fields follow (bound values, page size,
        paging state, serial consistency, timestamp), then those fields.

        :param f: writable binary buffer the body is appended to
        :param protocol_version: native protocol version being spoken
        :raises UnsupportedOperation: when serial consistency or paging is used
            below protocol v2, or a protocol-level timestamp below v3
        """
        write_longstring(f, self.query)
        write_consistency_level(f, self.consistency_level)
        flags = 0x00
        if self._query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.timestamp is not None:
            # The timestamp flag is not defined before v3; refuse instead of
            # emitting an undefined flag bit (matches ExecuteMessage).
            if protocol_version >= 3:
                flags |= _PROTOCOL_TIMESTAMP
            else:
                raise UnsupportedOperation(
                    "Protocol-level timestamps may only be used with protocol version "
                    "3 or higher. Consider setting Cluster.protocol_version to 3.")

        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        if self._query_params is not None:
            write_short(f, len(self._query_params))
            for param in self._query_params:
                write_value(f, param)

        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
Ejemplo n.º 16
0
    def test_protocol_downgrade_test(self):
        """
        Check the protocol downgrade chain and that the error-code-map and
        int-query-flag helpers are on for DSE_V1 and off for V4.
        """
        downgrades = [
            (ProtocolVersion.DSE_V2, ProtocolVersion.DSE_V1),
            (ProtocolVersion.DSE_V1, ProtocolVersion.V4),
            (ProtocolVersion.V4, ProtocolVersion.V3),
            (ProtocolVersion.V3, ProtocolVersion.V2),
            (ProtocolVersion.V2, ProtocolVersion.V1),
            (ProtocolVersion.V1, 0),
        ]
        for version, expected_lower in downgrades:
            self.assertEqual(expected_lower,
                             ProtocolVersion.get_lower_supported(version))

        feature_checks = (ProtocolVersion.uses_error_code_map,
                          ProtocolVersion.uses_int_query_flags)
        for check in feature_checks:
            self.assertTrue(check(ProtocolVersion.DSE_V1))
        for check in feature_checks:
            self.assertFalse(check(ProtocolVersion.V4))
Ejemplo n.º 17
0
 def send_body(self, f, protocol_version):
     """
     Serialize the revise-request body onto ``f``.

     Writes the revision type and the id of the operation being revised;
     for PAGING_BACKPRESSURE revisions ``next_pages`` is appended.

     :param f: writable binary buffer the body is appended to
     :param protocol_version: native protocol version being spoken
     :raises UnsupportedOperation: for a backpressure revision whose
         ``next_pages`` is missing or not positive, or when the protocol
         version lacks continuous-paging ``next_pages`` support
     """
     is_backpressure = (
         self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE)

     # Validate before writing anything so a rejected request never leaves a
     # partially serialized body in ``f``.
     if is_backpressure:
         if self.next_pages is None or self.next_pages <= 0:
             raise UnsupportedOperation(
                 "Continuous paging backpressure requires next_pages > 0")
         if not ProtocolVersion.has_continuous_paging_next_pages(
                 protocol_version):
             raise UnsupportedOperation(
                 "Continuous paging backpressure may only be used with protocol version "
                 "ProtocolVersion.DSE_V2 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V2."
             )

     write_int(f, self.op_type)
     write_int(f, self.op_id)
     if is_backpressure:
         write_int(f, self.next_pages)
Ejemplo n.º 18
0
    def test_continuous_paging(self):
        """
        Test to check continuous paging throws an Exception if it's not supported and the correct values
        are written to the buffer if the option is enabled.

        @since DSE 2.0b3 GRAPH 1.0b1
        @jira_ticket PYTHON-694
        @expected_result the values are correctly written

        @test_category connection
        """
        max_pages = 4
        max_pages_per_second = 3
        paging_options = ContinuousPagingOptions(
            max_pages=max_pages, max_pages_per_second=max_pages_per_second)
        message = QueryMessage(
            "a", 3, continuous_paging_options=paging_options)
        buf = Mock()

        unsupported_versions = [
            v for v in ProtocolVersion.SUPPORTED_VERSIONS
            if not ProtocolVersion.has_continuous_paging_support(v)
        ]
        for version in unsupported_versions:
            with self.assertRaises(UnsupportedOperation):
                message.send_body(buf, version)

        buf.reset_mock()
        message.send_body(buf, ProtocolVersion.DSE_V1)
        writes = buf.write.mock_calls

        # continuous paging adds two write calls to the buffer
        self.assertEqual(len(writes), 6)

        # The fourth write carries the flags word; only the paging-options
        # bit should be set among those checked here.
        flags = uint32_unpack(writes[3][1][0])
        self.assertEqual(flags & _WITH_SERIAL_CONSISTENCY_FLAG, 0)
        self.assertEqual(flags & _PAGE_SIZE_FLAG, 0)
        self.assertEqual(flags & _WITH_PAGING_STATE_FLAG, 0)
        self.assertEqual(flags & _PAGING_OPTIONS_FLAG, _PAGING_OPTIONS_FLAG)

        # max_pages and max_pages_per_second follow the flags, in that order
        self.assertEqual(uint32_unpack(writes[4][1][0]), max_pages)
        self.assertEqual(uint32_unpack(writes[5][1][0]), max_pages_per_second)
Ejemplo n.º 19
0
    def test_prepare_flag_with_keyspace(self):
        """
        A PREPARE with a keyspace serializes the flag and keyspace on versions
        that use the keyspace flag, and raises UnsupportedOperation otherwise.
        """
        message = PrepareMessage("a", keyspace='ks')
        buf = Mock()
        expected_writes = [
            (b'\x00\x00\x00\x01',),  # query length
            (b'a',),                 # query text
            (b'\x00\x00\x00\x01',),  # flags word
            (b'\x00\x02',),          # keyspace length
            (b'ks',),                # keyspace
        ]

        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            if not ProtocolVersion.uses_keyspace_flag(version):
                with self.assertRaises(UnsupportedOperation):
                    message.send_body(buf, version)
            else:
                message.send_body(buf, version)
                self._check_calls(buf, expected_writes)
            buf.reset_mock()
Ejemplo n.º 20
0
    def test_prepare_flag_with_keyspace(self):
        """
        A PREPARE with a keyspace serializes the flag and keyspace on versions
        that use the keyspace flag, and raises UnsupportedOperation otherwise.
        """
        message = PrepareMessage("a", keyspace='ks')
        buf = Mock()
        expected_writes = [
            (b'\x00\x00\x00\x01', ),  # query length
            (b'a', ),                 # query text
            (b'\x00\x00\x00\x01', ),  # flags word
            (b'\x00\x02', ),          # keyspace length
            (b'ks', ),                # keyspace
        ]

        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            if not ProtocolVersion.uses_keyspace_flag(version):
                with self.assertRaises(UnsupportedOperation):
                    message.send_body(buf, version)
            else:
                message.send_body(buf, version)
                self._check_calls(buf, expected_writes)
            buf.reset_mock()
Ejemplo n.º 21
0
    def test_prepare_flag(self):
        """
        Test to check the prepare flag is properly set, This should only happen for V5 at the moment.

        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written

        @test_category connection
        """
        message = PrepareMessage("a")
        buf = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            message.send_body(buf, version)
            # The query string takes two writes; a flags word adds a third.
            expected_count = 3 if ProtocolVersion.uses_prepare_flags(version) else 2
            self.assertEqual(len(buf.write.mock_calls), expected_count)
            buf.reset_mock()
Ejemplo n.º 22
0
    def test_prepare_flag(self):
        """
        Test to check the prepare flag is properly set, This should only happen for V5 at the moment.

        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written

        @test_category connection
        """
        message = PrepareMessage("a")
        buf = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            message.send_body(buf, version)
            # The query string takes two writes; a flags word adds a third.
            expected_count = 3 if ProtocolVersion.uses_prepare_flags(version) else 2
            self.assertEqual(len(buf.write.mock_calls), expected_count)
            buf.reset_mock()
Ejemplo n.º 23
0
    def encode_message(cls, msg, stream_id, protocol_version, compressor,
                       allow_beta_protocol_version):
        """
        Encode a message into a complete frame (header followed by body).

        :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
        :param stream_id: protocol stream id for the frame header
        :param protocol_version: version for the frame header, and used encoding contents
        :param compressor: optional compression function to be used on the body
        :param allow_beta_protocol_version: when true, USE_BETA_FLAG is set in the header flags
        :return: the encoded frame as bytes
        :raises UnsupportedOperation: if ``msg`` carries a custom payload and
            ``protocol_version`` is below 4
        """
        header_flags = 0
        body_buf = io.BytesIO()

        payload = msg.custom_payload
        if payload:
            if protocol_version < 4:
                raise UnsupportedOperation(
                    "Custom key/value payloads can only be used with protocol version 4 or higher"
                )
            header_flags |= CUSTOM_PAYLOAD_FLAG
            # The payload map precedes the message-specific body.
            write_bytesmap(body_buf, payload)

        msg.send_body(body_buf, protocol_version)
        body = body_buf.getvalue()

        # With checksumming, the compression is done at the segment frame encoding
        frame_level_compression = not ProtocolVersion.has_checksumming_support(
            protocol_version)
        if frame_level_compression and compressor and len(body) > 0:
            body = compressor(body)
            header_flags |= COMPRESSED_FLAG

        if msg.tracing:
            header_flags |= TRACING_FLAG
        if allow_beta_protocol_version:
            header_flags |= USE_BETA_FLAG

        frame = io.BytesIO()
        cls._write_header(frame, protocol_version, header_flags, stream_id,
                          msg.opcode, len(body))
        frame.write(body)
        return frame.getvalue()
Ejemplo n.º 24
0
    def recv_error_info(f, protocol_version):
        """
        Read the error-specific fields (consistency, response counts,
        failures and write type) from ``f``.

        On protocol versions that use an error code map, the map is read and
        ``failures`` is its length; otherwise ``failures`` is read directly
        and ``error_code_map`` is None.

        :return: dict with keys consistency, received_responses,
            required_responses, failures, error_code_map, write_type
        """
        info = {}
        info['consistency'] = read_consistency_level(f)
        info['received_responses'] = read_int(f)
        info['required_responses'] = read_int(f)

        if ProtocolVersion.uses_error_code_map(protocol_version):
            code_map = read_error_code_map(f)
            info['failures'] = len(code_map)
        else:
            code_map = None
            info['failures'] = read_int(f)
        info['error_code_map'] = code_map

        info['write_type'] = WriteType.name_to_value[read_string(f)]
        return info
Ejemplo n.º 25
0
    def recv_error_info(f, protocol_version):
        """
        Read the error-specific fields of a failure message that reports data retrieval.

        Fields are consumed from ``f`` in this order: consistency level,
        received responses, required responses, then either an error code map
        (protocol versions that use one) or a plain failure count, and finally
        one byte whose truthiness says whether data was retrieved.

        :param f: file-like object positioned at the error-specific fields
        :param protocol_version: selects error-code-map vs. failure-count encoding
        :return: dict with keys consistency, received_responses,
            required_responses, failures, error_code_map and data_retrieved
        """
        level = read_consistency_level(f)
        received = read_int(f)
        required = read_int(f)

        uses_map = ProtocolVersion.uses_error_code_map(protocol_version)
        code_map = read_error_code_map(f) if uses_map else None
        failure_count = len(code_map) if uses_map else read_int(f)

        retrieved = bool(read_byte(f))

        return dict(
            consistency=level,
            received_responses=received,
            required_responses=required,
            failures=failure_count,
            error_code_map=code_map,
            data_retrieved=retrieved,
        )
Ejemplo n.º 26
0
    def recv_error_info(f, protocol_version):
        """
        Read the error-specific fields of a failure message that carries a write type.

        Fields are consumed from ``f`` in this order: consistency level,
        received responses, required responses, then either an error code map
        (protocol versions that use one) or a plain failure count, and finally
        the write type name.

        :param f: file-like object positioned at the error-specific fields
        :param protocol_version: selects error-code-map vs. failure-count encoding
        :return: dict with keys consistency, received_responses,
            required_responses, failures, error_code_map and write_type
        """
        info = {}
        info['consistency'] = read_consistency_level(f)
        info['received_responses'] = read_int(f)
        info['required_responses'] = read_int(f)

        code_map = None
        if ProtocolVersion.uses_error_code_map(protocol_version):
            code_map = read_error_code_map(f)
            failure_count = len(code_map)
        else:
            failure_count = read_int(f)
        info['failures'] = failure_count
        info['error_code_map'] = code_map

        # The write type arrives as its name and is mapped to the enum value.
        info['write_type'] = WriteType.name_to_value[read_string(f)]
        return info
Ejemplo n.º 27
0
    def recv_error_info(f, protocol_version):
        """
        Read the error-specific fields of a failure message that reports data retrieval.

        Fields are consumed from ``f`` in this order: consistency level,
        received responses, required responses, then either an error code map
        (protocol versions that use one) or a plain failure count, and finally
        one byte whose truthiness says whether data was retrieved.

        :param f: file-like object positioned at the error-specific fields
        :param protocol_version: selects error-code-map vs. failure-count encoding
        :return: dict with keys consistency, received_responses,
            required_responses, failures, error_code_map and data_retrieved
        """
        level = read_consistency_level(f)
        received = read_int(f)
        required = read_int(f)

        uses_map = ProtocolVersion.uses_error_code_map(protocol_version)
        code_map = read_error_code_map(f) if uses_map else None
        failure_count = len(code_map) if uses_map else read_int(f)

        retrieved = bool(read_byte(f))

        return dict(
            consistency=level,
            received_responses=received,
            required_responses=required,
            failures=failure_count,
            error_code_map=code_map,
            data_retrieved=retrieved,
        )
Ejemplo n.º 28
0
    def _perform_cql_statement(self, text, consistency_level, expected_exception, session=None):
        """
        Execute a CQL statement and check that it raises the expected exception, if any.

        @param text CQL statement to execute
        @param consistency_level Consistency level at which it is to be executed
        @param expected_exception Exception class expected to be thrown, or None
               if the statement is expected to succeed
        @param session Session to execute on; defaults to self.session
        """
        if session is None:
            session = self.session
        statement = SimpleStatement(text)
        statement.consistency_level = consistency_level

        if expected_exception is None:
            self.execute_helper(session, statement)
            return

        with self.assertRaises(expected_exception) as cm:
            self.execute_helper(session, statement)

        # Only protocol versions with an error code map report per-node failure
        # codes. The expected values (1 for read failures, 0 for write failures)
        # are presumably server failure-reason codes for this test's setup; confirm.
        if ProtocolVersion.uses_error_code_map(PROTOCOL_VERSION):
            if isinstance(cm.exception, ReadFailure):
                self.assertEqual(list(cm.exception.error_code_map.values())[0], 1)
            elif isinstance(cm.exception, WriteFailure):
                self.assertEqual(list(cm.exception.error_code_map.values())[0], 0)
Ejemplo n.º 29
0
 def _write_paging_options(self, f, paging_options, protocol_version):
     """
     Write the continuous paging options as consecutive ints: max pages,
     max pages per second, and, on protocol versions that support it,
     the max queue size.
     """
     values = [paging_options.max_pages, paging_options.max_pages_per_second]
     if ProtocolVersion.has_continuous_paging_next_pages(protocol_version):
         values.append(paging_options.max_queue_size)
     for value in values:
         write_int(f, value)
Ejemplo n.º 30
0
    def _write_query_params(self, f, protocol_version):
        """
        Write the query parameters: consistency level, a flags value, then each
        optional field whose flag was set.

        The flags value is built first from which optional attributes are set.
        For each attribute the given protocol version cannot carry,
        UnsupportedOperation is raised. The optional fields are then written
        after the flags in a fixed order.

        :param f: file-like object the encoded parameters are written to
        :param protocol_version: protocol version in use; decides which optional
            fields are allowed and whether flags are written via write_uint or
            write_byte
        :raises UnsupportedOperation: if serial consistency, paging (fetch size or
            paging state), continuous paging, or a keyspace is set on a protocol
            version that does not support it
        """
        # NOTE: the consistency level is written before the checks below, so an
        # UnsupportedOperation leaves partial output in f.
        write_consistency_level(f, self.consistency_level)
        flags = 0x00
        if self.query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        # Truthiness check: a fetch_size of 0 (or None) sets no flag and writes nothing.
        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2."
                )

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2."
                )

        # Identity check (unlike fetch_size): a timestamp of 0 is still sent.
        if self.timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP_FLAG

        if self.continuous_paging_options:
            if ProtocolVersion.has_continuous_paging_support(protocol_version):
                flags |= _PAGING_OPTIONS_FLAG
            else:
                raise UnsupportedOperation(
                    "Continuous paging may only be used with protocol version "
                    "ProtocolVersion.DSE_V1 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V1."
                )

        if self.keyspace is not None:
            if ProtocolVersion.uses_keyspace_flag(protocol_version):
                flags |= _WITH_KEYSPACE_FLAG
            else:
                raise UnsupportedOperation(
                    "Keyspaces may only be set on queries with protocol version "
                    "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version."
                )

        # Newer protocol versions use an int-sized flags field; older ones a single byte.
        if ProtocolVersion.uses_int_query_flags(protocol_version):
            write_uint(f, flags)
        else:
            write_byte(f, flags)

        # Optional fields, each guarded by the same condition that set its flag.
        # The write order (values, fetch size, paging state, serial consistency,
        # timestamp, keyspace, continuous paging) differs from the order the flags
        # were checked above; presumably it follows the wire format, so keep it.
        if self.query_params is not None:
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)
        if self.keyspace is not None:
            write_string(f, self.keyspace)
        if self.continuous_paging_options:
            self._write_paging_options(f, self.continuous_paging_options,
                                       protocol_version)
Ejemplo n.º 31
0
 def send_body(self, f, protocol_version):
     """
     Write the message body: the query as a long string, followed by a flags
     field when the protocol version supports prepare flags.
     """
     write_longstring(f, self.query)
     if not ProtocolVersion.uses_prepare_flags(protocol_version):
         return
     # No prepare flags are set yet (PYTHON-678), so a zero flags value is
     # written with write_uint, i.e. as an int-sized field rather than a byte.
     write_uint(f, 0)
Ejemplo n.º 32
0
 def send_body(self, f, protocol_version):
     """
     Write the message body: the prepared query id, the result metadata id when
     the protocol version carries prepared metadata, then the query parameters.
     """
     ids = [self.query_id]
     if ProtocolVersion.uses_prepared_metadata(protocol_version):
         ids.append(self.result_metadata_id)
     for identifier in ids:
         write_string(f, identifier)
     self._write_query_params(f, protocol_version)
Ejemplo n.º 33
0
 def send_body(self, f, protocol_version):
     """Write the query as a long string, plus a zero prepare-flags field where supported."""
     write_longstring(f, self.query)
     has_flags = ProtocolVersion.uses_prepare_flags(protocol_version)
     if has_flags:
         # Nothing sets prepare flags yet (PYTHON-678); the field goes out as a
         # zero written with write_uint.
         write_uint(f, 0)