def test_process_scpv1(self):
        """Drive ``process`` through a series of search command protocol version 1 scenarios.

        The scenarios run in order: an action that is neither ``__GETINFO__`` nor ``__EXECUTE__``, the getinfo probe
        with default and then with explicit configuration settings, option parsing errors for both actions, a
        streaming command that raises while processing, and access to search results info and the service object
        with and without an ``infoPath`` input header.
        """

        # TestCommand.process should complain if supports_getinfo == False
        # We support dynamic configuration, not static

        # The exception line number may change, so we're using a regex match instead of a string match

        expected = re.compile(
            r'error_message=RuntimeError at ".+search_command\.py", line \d\d\d : Command test appears to be '
            r'statically configured for search command protocol version 1 and static configuration is unsupported by '
            r'splunklib.searchcommands. Please ensure that default/commands.conf contains this stanza:\n'
            r'\[test\]\n'
            r'filename = test.py\n'
            r'enableheader = true\n'
            r'outputheader = true\n'
            r'requires_srinfo = true\n'
            r'supports_getinfo = true\n'
            r'supports_multivalues = true\n'
            r'supports_rawargs = true')

        argv = ['test.py', 'not__GETINFO__or__EXECUTE__', 'option=value', 'fieldname']
        command = TestCommand()
        result = StringIO()

        self.assertRaises(SystemExit, command.process, argv, ofile=result)
        self.assertRegexpMatches(result.getvalue(), expected)

        # TestCommand.process should return configuration settings on Getinfo probe

        argv = ['test.py', '__GETINFO__', 'required_option_1=value', 'required_option_2=value']
        command = TestCommand()
        ifile = StringIO('\n')
        result = StringIO()

        self.assertEqual(str(command.configuration), '')

        if six.PY2:
            expected = ("[(u'clear_required_fields', None, [1]), (u'distributed', None, [2]), (u'generates_timeorder', None, [1]), "
            "(u'generating', None, [1, 2]), (u'maxinputs', None, [2]), (u'overrides_timeorder', None, [1]), "
            "(u'required_fields', None, [1, 2]), (u'requires_preop', None, [1]), (u'retainsevents', None, [1]), "
            "(u'run_in_preview', None, [2]), (u'streaming', None, [1]), (u'streaming_preop', None, [1, 2]), "
            "(u'type', None, [2])]")
        else:
            expected = ("[('clear_required_fields', None, [1]), ('distributed', None, [2]), ('generates_timeorder', None, [1]), "
            "('generating', None, [1, 2]), ('maxinputs', None, [2]), ('overrides_timeorder', None, [1]), "
            "('required_fields', None, [1, 2]), ('requires_preop', None, [1]), ('retainsevents', None, [1]), "
            "('run_in_preview', None, [2]), ('streaming', None, [1]), ('streaming_preop', None, [1, 2]), "
            "('type', None, [2])]")

        self.assertEqual(
            repr(command.configuration), expected)

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('{0}: {1}: {2}\n'.format(type(error).__name__, error, result.getvalue()))

        self.assertEqual('\r\n\r\n\r\n', result.getvalue())  # No message header and no configuration settings

        ifile = StringIO('\n')
        result = StringIO()

        # We might also put this sort of code into our SearchCommand.prepare override ...

        configuration = command.configuration

        # SCP v1/v2 configuration settings
        configuration.generating = True
        configuration.required_fields = ['foo', 'bar']
        configuration.streaming_preop = 'some streaming command'

        # SCP v1 configuration settings
        configuration.clear_required_fields = True
        configuration.generates_timeorder = True
        configuration.overrides_timeorder = True
        configuration.requires_preop = True
        configuration.retainsevents = True
        configuration.streaming = True

        # SCP v2 configuration settings (SCP v1 requires that maxinputs and run_in_preview are set in commands.conf)
        configuration.distributed = True
        configuration.maxinputs = 50000
        configuration.run_in_preview = True
        configuration.type = 'streaming'

        if six.PY2:
            expected = ('clear_required_fields="True", generates_timeorder="True", generating="True", overrides_timeorder="True", '
                'required_fields="[u\'foo\', u\'bar\']", requires_preop="True", retainsevents="True", streaming="True", '
                'streaming_preop="some streaming command"')
        else:
            expected = ('clear_required_fields="True", generates_timeorder="True", generating="True", overrides_timeorder="True", '
                        'required_fields="[\'foo\', \'bar\']", requires_preop="True", retainsevents="True", streaming="True", '
                        'streaming_preop="some streaming command"')
        self.assertEqual(str(command.configuration), expected)

        if six.PY2:
            expected = ("[(u'clear_required_fields', True, [1]), (u'distributed', True, [2]), (u'generates_timeorder', True, [1]), "
            "(u'generating', True, [1, 2]), (u'maxinputs', 50000, [2]), (u'overrides_timeorder', True, [1]), "
            "(u'required_fields', [u'foo', u'bar'], [1, 2]), (u'requires_preop', True, [1]), "
            "(u'retainsevents', True, [1]), (u'run_in_preview', True, [2]), (u'streaming', True, [1]), "
            "(u'streaming_preop', u'some streaming command', [1, 2]), (u'type', u'streaming', [2])]")
        else:
            expected = ("[('clear_required_fields', True, [1]), ('distributed', True, [2]), ('generates_timeorder', True, [1]), "
            "('generating', True, [1, 2]), ('maxinputs', 50000, [2]), ('overrides_timeorder', True, [1]), "
            "('required_fields', ['foo', 'bar'], [1, 2]), ('requires_preop', True, [1]), "
            "('retainsevents', True, [1]), ('run_in_preview', True, [2]), ('streaming', True, [1]), "
            "('streaming_preop', 'some streaming command', [1, 2]), ('type', 'streaming', [2])]")

        self.assertEqual(
            repr(command.configuration), expected)

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('{0}: {1}: {2}\n'.format(type(error).__name__, error, result.getvalue()))

        # The output is read back as CSV: an empty line, a header row and exactly one value row
        result.seek(0)
        reader = csv.reader(result)
        self.assertEqual([], next(reader))
        observed = dict(izip(next(reader), next(reader)))
        self.assertRaises(StopIteration, lambda: next(reader))

        expected = {
            'clear_required_fields': '1',                '__mv_clear_required_fields': '',
            'generating': '1',                           '__mv_generating': '',
            'generates_timeorder': '1',                  '__mv_generates_timeorder': '',
            'overrides_timeorder': '1',                  '__mv_overrides_timeorder': '',
            'requires_preop': '1',                       '__mv_requires_preop': '',
            'required_fields': 'foo,bar',                '__mv_required_fields': '',
            'retainsevents': '1',                        '__mv_retainsevents': '',
            'streaming': '1',                            '__mv_streaming': '',
            'streaming_preop': 'some streaming command', '__mv_streaming_preop': '',
        }

        self.assertDictEqual(expected, observed)  # No message header and no configuration settings

        for action in '__GETINFO__', '__EXECUTE__':

            # TestCommand.process should produce an error record on parser errors

            argv = [
                'test.py', action, 'required_option_1=value', 'required_option_2=value', 'undefined_option=value',
                'fieldname_1', 'fieldname_2']

            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit, command.process, argv, ifile, ofile=result)

            # assertEqual rather than assertTrue: assertTrue(text, output) treats the output as the failure message
            # and passes for any non-empty text, so it never checked what was written
            self.assertEqual(
                'error_message=Unrecognized test command option: undefined_option="value"\r\n\r\n',
                result.getvalue())

            # TestCommand.process should produce an error record when required options are missing

            argv = ['test.py', action, 'required_option_2=value', 'fieldname_1']
            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit, command.process, argv, ifile, ofile=result)

            self.assertEqual(
                'error_message=A value for test command option required_option_1 is required\r\n\r\n',
                result.getvalue())

            argv = ['test.py', action, 'fieldname_1']
            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit, command.process, argv, ifile, ofile=result)

            self.assertEqual(
                'error_message=Values for these test command options are required: required_option_1, required_option_2'
                '\r\n\r\n',
                result.getvalue())

        # TestStreamingCommand.process should exit on processing exceptions

        ifile = StringIO('\naction\r\nraise_error\r\n')
        argv = ['test.py', '__EXECUTE__']
        command = TestStreamingCommand()
        result = StringIO()

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except SystemExit as error:
            self.assertNotEqual(error.code, 0)
            self.assertRegexpMatches(
                result.getvalue(),
                r'^error_message=RuntimeError at ".+", line \d+ : Testing\r\n\r\n$')
        except BaseException as error:
            self.fail('Expected SystemExit, but caught {}: {}'.format(type(error).__name__, error))
        else:
            self.fail('Expected SystemExit, but no exception was raised')

        # Command.process should provide access to search results info
        info_path = os.path.join(
            self._package_directory, 'recordings', 'scpv1', 'Splunk-6.3', 'countmatches.execute.dispatch_dir',
            'externSearchResultsInfo.csv')

        ifile = StringIO('infoPath:' + info_path + '\n\naction\r\nget_search_results_info\r\n')
        argv = ['test.py', '__EXECUTE__']
        command = TestStreamingCommand()
        result = StringIO()

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('Expected no exception, but caught {}: {}'.format(type(error).__name__, error))
        else:
            # Either column order is accepted
            self.assertRegexpMatches(
                result.getvalue(),
                r'^\r\n'
                r'('
                r'data,__mv_data,_serial,__mv__serial\r\n'
                r'"\{.*u\'is_summary_index\': 0, .+\}",,0,'
                r'|'
                r'_serial,__mv__serial,data,__mv_data\r\n'
                r'0,,"\{.*u\'is_summary_index\': 0, .+\}",'
                r')'
                r'\r\n$'
            )

        # TestStreamingCommand.process should provide access to a service object when search results info is available

        self.assertIsInstance(command.service, Service)

        self.assertEqual(command.service.authority,
                         command.search_results_info.splunkd_uri)

        self.assertEqual(command.service.scheme,
                         command.search_results_info.splunkd_protocol)

        self.assertEqual(command.service.port,
                         command.search_results_info.splunkd_port)

        self.assertEqual(command.service.token,
                         command.search_results_info.auth_token)

        self.assertEqual(command.service.namespace.app,
                         command.search_results_info.ppc_app)

        self.assertEqual(command.service.namespace.owner,
                         None)
        self.assertEqual(command.service.namespace.sharing,
                         None)

        # Command.process should not provide access to search results info or a service object when the 'infoPath'
        # input header is unavailable

        ifile = StringIO('\naction\r\nget_search_results_info')
        argv = ['teststreaming.py', '__EXECUTE__']
        command = TestStreamingCommand()

        # noinspection PyTypeChecker
        command.process(argv, ifile, ofile=result)

        self.assertIsNone(command.search_results_info)
        self.assertIsNone(command.service)

        return
    def test_process_scpv1(self):
        """Drive ``process`` through a series of search command protocol version 1 scenarios.

        The scenarios run in order: an action that is neither ``__GETINFO__`` nor ``__EXECUTE__``, the getinfo probe
        with default and then with explicit configuration settings, option parsing errors for both actions, a
        streaming command that raises while processing, and access to search results info and the service object
        with and without an ``infoPath`` input header.
        """

        # TestCommand.process should complain if supports_getinfo == False
        # We support dynamic configuration, not static

        # The exception line number may change, so we're using a regex match instead of a string match

        expected = re.compile(
            r'error_message=RuntimeError at ".+search_command\.py", line \d\d\d : Command test appears to be '
            r'statically configured for search command protocol version 1 and static configuration is unsupported by '
            r'splunklib.searchcommands. Please ensure that default/commands.conf contains this stanza:\n'
            r'\[test\]\n'
            r'filename = test.py\n'
            r'enableheader = true\n'
            r'outputheader = true\n'
            r'requires_srinfo = true\n'
            r'supports_getinfo = true\n'
            r'supports_multivalues = true\n'
            r'supports_rawargs = true')

        argv = [
            'test.py', 'not__GETINFO__or__EXECUTE__', 'option=value',
            'fieldname'
        ]
        command = TestCommand()
        result = StringIO()

        self.assertRaises(SystemExit, command.process, argv, ofile=result)
        self.assertRegexpMatches(result.getvalue(), expected)

        # TestCommand.process should return configuration settings on Getinfo probe

        argv = [
            'test.py', '__GETINFO__', 'required_option_1=value',
            'required_option_2=value'
        ]
        command = TestCommand()
        ifile = StringIO('\n')
        result = StringIO()

        self.assertEqual(str(command.configuration), '')

        if six.PY2:
            expected = (
                "[(u'clear_required_fields', None, [1]), (u'distributed', None, [2]), (u'generates_timeorder', None, [1]), "
                "(u'generating', None, [1, 2]), (u'maxinputs', None, [2]), (u'overrides_timeorder', None, [1]), "
                "(u'required_fields', None, [1, 2]), (u'requires_preop', None, [1]), (u'retainsevents', None, [1]), "
                "(u'run_in_preview', None, [2]), (u'streaming', None, [1]), (u'streaming_preop', None, [1, 2]), "
                "(u'type', None, [2])]")
        else:
            expected = (
                "[('clear_required_fields', None, [1]), ('distributed', None, [2]), ('generates_timeorder', None, [1]), "
                "('generating', None, [1, 2]), ('maxinputs', None, [2]), ('overrides_timeorder', None, [1]), "
                "('required_fields', None, [1, 2]), ('requires_preop', None, [1]), ('retainsevents', None, [1]), "
                "('run_in_preview', None, [2]), ('streaming', None, [1]), ('streaming_preop', None, [1, 2]), "
                "('type', None, [2])]")

        self.assertEqual(repr(command.configuration), expected)

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('{0}: {1}: {2}\n'.format(
                type(error).__name__, error, result.getvalue()))

        self.assertEqual('\r\n\r\n\r\n', result.getvalue()
                         )  # No message header and no configuration settings

        ifile = StringIO('\n')
        result = StringIO()

        # We might also put this sort of code into our SearchCommand.prepare override ...

        configuration = command.configuration

        # SCP v1/v2 configuration settings
        configuration.generating = True
        configuration.required_fields = ['foo', 'bar']
        configuration.streaming_preop = 'some streaming command'

        # SCP v1 configuration settings
        configuration.clear_required_fields = True
        configuration.generates_timeorder = True
        configuration.overrides_timeorder = True
        configuration.requires_preop = True
        configuration.retainsevents = True
        configuration.streaming = True

        # SCP v2 configuration settings (SCP v1 requires that maxinputs and run_in_preview are set in commands.conf)
        configuration.distributed = True
        configuration.maxinputs = 50000
        configuration.run_in_preview = True
        configuration.type = 'streaming'

        if six.PY2:
            expected = (
                'clear_required_fields="True", generates_timeorder="True", generating="True", overrides_timeorder="True", '
                'required_fields="[u\'foo\', u\'bar\']", requires_preop="True", retainsevents="True", streaming="True", '
                'streaming_preop="some streaming command"')
        else:
            expected = (
                'clear_required_fields="True", generates_timeorder="True", generating="True", overrides_timeorder="True", '
                'required_fields="[\'foo\', \'bar\']", requires_preop="True", retainsevents="True", streaming="True", '
                'streaming_preop="some streaming command"')
        self.assertEqual(str(command.configuration), expected)

        if six.PY2:
            expected = (
                "[(u'clear_required_fields', True, [1]), (u'distributed', True, [2]), (u'generates_timeorder', True, [1]), "
                "(u'generating', True, [1, 2]), (u'maxinputs', 50000, [2]), (u'overrides_timeorder', True, [1]), "
                "(u'required_fields', [u'foo', u'bar'], [1, 2]), (u'requires_preop', True, [1]), "
                "(u'retainsevents', True, [1]), (u'run_in_preview', True, [2]), (u'streaming', True, [1]), "
                "(u'streaming_preop', u'some streaming command', [1, 2]), (u'type', u'streaming', [2])]"
            )
        else:
            expected = (
                "[('clear_required_fields', True, [1]), ('distributed', True, [2]), ('generates_timeorder', True, [1]), "
                "('generating', True, [1, 2]), ('maxinputs', 50000, [2]), ('overrides_timeorder', True, [1]), "
                "('required_fields', ['foo', 'bar'], [1, 2]), ('requires_preop', True, [1]), "
                "('retainsevents', True, [1]), ('run_in_preview', True, [2]), ('streaming', True, [1]), "
                "('streaming_preop', 'some streaming command', [1, 2]), ('type', 'streaming', [2])]"
            )

        self.assertEqual(repr(command.configuration), expected)

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('{0}: {1}: {2}\n'.format(
                type(error).__name__, error, result.getvalue()))

        # The output is read back as CSV: an empty line, a header row and exactly one value row
        result.seek(0)
        reader = csv.reader(result)
        self.assertEqual([], next(reader))
        observed = dict(izip(next(reader), next(reader)))
        self.assertRaises(StopIteration, lambda: next(reader))

        expected = {
            'clear_required_fields': '1',
            '__mv_clear_required_fields': '',
            'generating': '1',
            '__mv_generating': '',
            'generates_timeorder': '1',
            '__mv_generates_timeorder': '',
            'overrides_timeorder': '1',
            '__mv_overrides_timeorder': '',
            'requires_preop': '1',
            '__mv_requires_preop': '',
            'required_fields': 'foo,bar',
            '__mv_required_fields': '',
            'retainsevents': '1',
            '__mv_retainsevents': '',
            'streaming': '1',
            '__mv_streaming': '',
            'streaming_preop': 'some streaming command',
            '__mv_streaming_preop': '',
        }

        self.assertDictEqual(
            expected,
            observed)  # No message header and no configuration settings

        for action in '__GETINFO__', '__EXECUTE__':

            # TestCommand.process should produce an error record on parser errors

            argv = [
                'test.py', action, 'required_option_1=value',
                'required_option_2=value', 'undefined_option=value',
                'fieldname_1', 'fieldname_2'
            ]

            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit,
                              command.process,
                              argv,
                              ifile,
                              ofile=result)

            # assertEqual rather than assertTrue: assertTrue(text, output) treats the output as the failure message
            # and passes for any non-empty text, so it never checked what was written
            self.assertEqual(
                'error_message=Unrecognized test command option: undefined_option="value"\r\n\r\n',
                result.getvalue())

            # TestCommand.process should produce an error record when required options are missing

            argv = [
                'test.py', action, 'required_option_2=value', 'fieldname_1'
            ]
            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit,
                              command.process,
                              argv,
                              ifile,
                              ofile=result)

            self.assertEqual(
                'error_message=A value for test command option required_option_1 is required\r\n\r\n',
                result.getvalue())

            argv = ['test.py', action, 'fieldname_1']
            command = TestCommand()
            ifile = StringIO('\n')
            result = StringIO()

            self.assertRaises(SystemExit,
                              command.process,
                              argv,
                              ifile,
                              ofile=result)

            self.assertEqual(
                'error_message=Values for these test command options are required: required_option_1, required_option_2'
                '\r\n\r\n', result.getvalue())

        # TestStreamingCommand.process should exit on processing exceptions

        ifile = StringIO('\naction\r\nraise_error\r\n')
        argv = ['test.py', '__EXECUTE__']
        command = TestStreamingCommand()
        result = StringIO()

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except SystemExit as error:
            self.assertNotEqual(error.code, 0)
            self.assertRegexpMatches(
                result.getvalue(),
                r'^error_message=RuntimeError at ".+", line \d+ : Testing\r\n\r\n$'
            )
        except BaseException as error:
            self.fail('Expected SystemExit, but caught {}: {}'.format(
                type(error).__name__, error))
        else:
            self.fail('Expected SystemExit, but no exception was raised')

        # Command.process should provide access to search results info
        info_path = os.path.join(self._package_directory, 'recordings',
                                 'scpv1', 'Splunk-6.3',
                                 'countmatches.execute.dispatch_dir',
                                 'externSearchResultsInfo.csv')

        ifile = StringIO('infoPath:' + info_path +
                         '\n\naction\r\nget_search_results_info\r\n')
        argv = ['test.py', '__EXECUTE__']
        command = TestStreamingCommand()
        result = StringIO()

        try:
            # noinspection PyTypeChecker
            command.process(argv, ifile, ofile=result)
        except BaseException as error:
            self.fail('Expected no exception, but caught {}: {}'.format(
                type(error).__name__, error))
        else:
            # Either column order is accepted
            self.assertRegexpMatches(
                result.getvalue(), r'^\r\n'
                r'('
                r'data,__mv_data,_serial,__mv__serial\r\n'
                r'"\{.*u\'is_summary_index\': 0, .+\}",,0,'
                r'|'
                r'_serial,__mv__serial,data,__mv_data\r\n'
                r'0,,"\{.*u\'is_summary_index\': 0, .+\}",'
                r')'
                r'\r\n$')

        # TestStreamingCommand.process should provide access to a service object when search results info is available

        self.assertIsInstance(command.service, Service)

        self.assertEqual(command.service.authority,
                         command.search_results_info.splunkd_uri)

        self.assertEqual(command.service.scheme,
                         command.search_results_info.splunkd_protocol)

        self.assertEqual(command.service.port,
                         command.search_results_info.splunkd_port)

        self.assertEqual(command.service.token,
                         command.search_results_info.auth_token)

        self.assertEqual(command.service.namespace.app,
                         command.search_results_info.ppc_app)

        self.assertEqual(command.service.namespace.owner, None)
        self.assertEqual(command.service.namespace.sharing, None)

        # Command.process should not provide access to search results info or a service object when the 'infoPath'
        # input header is unavailable

        ifile = StringIO('\naction\r\nget_search_results_info')
        argv = ['teststreaming.py', '__EXECUTE__']
        command = TestStreamingCommand()

        # noinspection PyTypeChecker
        command.process(argv, ifile, ofile=result)

        self.assertIsNone(command.search_results_info)
        self.assertIsNone(command.service)

        return
示例#3
0
class RecordWriter(object):
    """Render records as CSV rows into a buffer and collect messages for an output file.

    Rows are written to an in-memory buffer through ``csv.writer`` and messages are gathered in the ``_inspector``
    mapping. ``flush`` only validates its arguments and the writer state here, so emitting the buffered data is
    presumably left to subclasses (none are visible in this section of the file).
    """

    def __init__(self, ofile, maxresultrows=None):
        """Create a writer that targets ``ofile``.

        :param ofile: Output file object; it is passed through ``set_binary_mode`` before being stored.
        :param maxresultrows: Number of pending records at which ``_write_record`` requests a partial flush;
            ``None`` selects 50000.
        """
        self._maxresultrows = 50000 if maxresultrows is None else maxresultrows

        self._ofile = set_binary_mode(ofile)
        self._fieldnames = None
        self._buffer = StringIO()

        self._writer = csv.writer(self._buffer, dialect=CsvDialect)
        self._writerow = self._writer.writerow
        self._finished = False
        self._flushed = False

        self._inspector = OrderedDict()
        self._chunk_count = 0
        self._pending_record_count = 0
        self._committed_record_count = 0

    @property
    def is_flushed(self):
        """The stored flushed flag."""
        return self._flushed

    @is_flushed.setter
    def is_flushed(self, value):
        # Any truthy or falsy value is normalized to a real bool
        self._flushed = True if value else False

    @property
    def ofile(self):
        """The output file object this writer writes to."""
        return self._ofile

    @ofile.setter
    def ofile(self, value):
        self._ofile = set_binary_mode(value)

    @property
    def pending_record_count(self):
        """Number of records written since the counter was last reset by ``_clear``."""
        return self._pending_record_count

    @property
    def _record_count(self):
        """Deprecated alias of ``pending_record_count``; emits a ``PendingDeprecationWarning``."""
        warnings.warn(
            "_record_count will be deprecated soon. Use pending_record_count instead.",
            PendingDeprecationWarning)
        return self.pending_record_count

    @property
    def committed_record_count(self):
        """The stored committed record counter (never modified by the code visible in this class)."""
        return self._committed_record_count

    @property
    def _total_record_count(self):
        """Deprecated alias of ``committed_record_count``; emits a ``PendingDeprecationWarning``."""
        warnings.warn(
            "_total_record_count will be deprecated soon. Use committed_record_count instead.",
            PendingDeprecationWarning)
        return self.committed_record_count

    def write(self, data):
        """Write ``data`` to the output file, encoding it as UTF-8 unless it already is a byte string."""
        bytes_type = bytes if sys.version_info >= (3, 0) else str
        if not isinstance(data, bytes_type):
            data = data.encode('utf-8')
        self.ofile.write(data)

    def flush(self, finished=None, partial=None):
        """Validate a flush request.

        Exactly one of ``finished`` and ``partial`` must be given, and it must be a ``bool``; this is enforced
        with ``assert`` statements.

        :raises RuntimeError: if the writer has already been finished.
        """
        assert finished is None or isinstance(finished, bool)
        assert partial is None or isinstance(partial, bool)
        assert not (finished is None and partial is None)
        assert finished is None or partial is None
        self._ensure_validity()

    def write_message(self, message_type, message_text, *args, **kwargs):
        """Record a message as a ``(message_type, text)`` pair under the ``'messages'`` inspector entry.

        ``message_text`` is expanded with ``str.format(*args, **kwargs)`` before it is stored.

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        self._inspector.setdefault('messages', []).append(
            (message_type, message_text.format(*args, **kwargs)))

    def write_record(self, record):
        """Write a single record (a mapping of field name to value).

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        self._write_record(record)

    def write_records(self, records):
        """Write every record of the iterable ``records``.

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        write_record = self._write_record
        for record in records:
            write_record(record)

    def _clear(self):
        """Empty the row buffer and the inspector and reset the pending record counter."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self._inspector.clear()
        self._pending_record_count = 0

    def _ensure_validity(self):
        """Raise ``RuntimeError`` when the writer has been finished."""
        if self._finished is True:
            # Read the counter through the supported property: the deprecated ``_record_count`` alias would emit
            # a PendingDeprecationWarning from library code, which replaces the RuntimeError below whenever
            # warnings are turned into errors
            assert self.pending_record_count == 0 and len(self._inspector) == 0
            raise RuntimeError('I/O operation on closed record writer')

    def _write_record(self, record):
        """Append one CSV row for ``record`` to the buffer.

        The first record fixes the field names and produces a header row in which every field name is followed
        by the same name prefixed with ``__mv_``. Each field then contributes two cells: its value and a
        multi-value cell. The multi-value cell is only filled for lists or tuples holding more than one item;
        there the value cell holds the items joined by newlines and the multi-value cell holds the items wrapped
        in ``$`` and separated by ``;``, with literal ``$`` doubled. A partial flush is requested once the number
        of pending records reaches the configured maximum.
        """

        fieldnames = self._fieldnames

        if fieldnames is None:
            self._fieldnames = fieldnames = list(record.keys())
            value_list = imap(lambda fn: (str(fn), str('__mv_') + str(fn)),
                              fieldnames)
            self._writerow(list(chain.from_iterable(value_list)))

        get_value = record.get
        values = []

        for fieldname in fieldnames:
            value = get_value(fieldname, None)

            if value is None:
                values += (None, None)
                continue

            value_t = type(value)

            if issubclass(value_t, (list, tuple)):

                if len(value) == 0:
                    values += (None, None)
                    continue

                if len(value) > 1:
                    value_list = value
                    sv = ''
                    mv = '$'

                    for value in value_list:

                        if value is None:
                            sv += '\n'
                            mv += '$;$'
                            continue

                        value_t = type(value)

                        # NOTE(review): a bytes item is concatenated with text below as-is, which only works on
                        # Python 2 - confirm whether bytes items can occur on Python 3
                        if value_t is not bytes:

                            if value_t is bool:
                                value = str(value.real)
                            elif value_t is six.text_type:
                                pass  # already text
                            elif isinstance(
                                    value, six.integer_types
                            ) or value_t is float or value_t is complex:
                                value = str(value)
                            elif issubclass(value_t, (dict, list, tuple)):
                                value = str(''.join(
                                    RecordWriter._iterencode_json(value, 0)))
                            else:
                                # Keep the repr as a native string, as the single-value path below does;
                                # encoding it would produce bytes on Python 3 and make the concatenation
                                # with text raise TypeError
                                value = repr(value)

                        sv += value + '\n'
                        mv += value.replace('$', '$$') + '$;$'

                    # Drop the trailing newline and the trailing ';$' separator
                    values += (sv[:-1], mv[:-2])
                    continue

                # A one-item list or tuple is written like a plain value
                value = value[0]
                value_t = type(value)

            if value_t is bool:
                values += (str(value.real), None)
                continue

            if value_t is bytes:
                values += (value, None)
                continue

            if value_t is six.text_type:
                if six.PY2:
                    value = value.encode('utf-8')
                values += (value, None)
                continue

            if isinstance(value, six.integer_types
                          ) or value_t is float or value_t is complex:
                values += (str(value), None)
                continue

            if issubclass(value_t, dict):
                values += (str(''.join(RecordWriter._iterencode_json(value,
                                                                     0))),
                           None)
                continue

            values += (repr(value), None)

        self._writerow(values)
        self._pending_record_count += 1

        if self.pending_record_count >= self._maxresultrows:
            self.flush(partial=True)

    try:
        # noinspection PyUnresolvedReferences
        from _json import make_encoder
    except ImportError:
        # We may be running under PyPy 2.5 which does not include the _json module
        _iterencode_json = JSONEncoder(separators=(',', ':')).iterencode
    else:
        # Creating _iterencode_json this way yields a two-fold performance improvement on Python 2.7.9 and 2.7.10
        from json.encoder import encode_basestring_ascii

        @staticmethod
        def _default(o):
            raise TypeError(repr(o) + ' is not JSON serializable')

        _iterencode_json = make_encoder(
            {},  # markers (for detecting circular references)
            _default,  # object_encoder
            encode_basestring_ascii,  # string_encoder
            None,  # indent
            ':',
            ',',  # separators
            False,  # sort_keys
            False,  # skip_keys
            True  # allow_nan
        )

        del make_encoder
示例#4
0
class RecordWriter(object):
    """Accumulate search result records as CSV rows in an in-memory buffer.

    Every field of a record occupies two adjacent CSV columns: ``<field>`` holds the value itself and
    ``__mv_<field>`` holds its multi-value encoding, which is ``None`` unless the value is a list or tuple with
    more than one item. In the multi-value encoding each item is wrapped in ``$`` (with embedded ``$`` doubled)
    and items are separated by ``;``; the plain column holds the same items joined by newlines.

    `flush` in this class only validates its arguments and the writer state; nothing here copies the buffer to
    `ofile`. NOTE(review): presumably subclasses elsewhere in the file override `flush` to do that -- confirm.
    """

    def __init__(self, ofile, maxresultrows=None):
        """Create a writer.

        :param ofile: Output file object; it is only stored by this class.
        :param maxresultrows: Number of buffered records at which `_write_record` calls ``flush(partial=True)``.
            ``None`` selects the default of 50000.
        """
        self._maxresultrows = 50000 if maxresultrows is None else maxresultrows

        self._ofile = ofile
        self._fieldnames = None
        self._buffer = StringIO()

        self._writer = csv.writer(self._buffer, dialect=CsvDialect)
        self._writerow = self._writer.writerow
        self._finished = False
        self._flushed = False

        self._inspector = OrderedDict()
        self._chunk_count = 0
        self._record_count = 0
        self._total_record_count = 0

    @property
    def is_flushed(self):
        """Flag stored in ``_flushed``; `_clear` resets it to ``False``."""
        return self._flushed

    @is_flushed.setter
    def is_flushed(self, value):
        # Any truthy/falsy value is normalized to a real bool.
        self._flushed = bool(value)

    @property
    def ofile(self):
        """The output file object given to the constructor (or assigned later)."""
        return self._ofile

    @ofile.setter
    def ofile(self, value):
        self._ofile = value

    def flush(self, finished=None, partial=None):
        """Validate a flush request; this implementation produces no output.

        Exactly one of `finished` and `partial` must be given, and it must be a ``bool``.

        :raises RuntimeError: if the writer has already been finished.
        """
        assert finished is None or isinstance(finished, bool)
        assert partial is None or isinstance(partial, bool)
        assert not (finished is None and partial is None)
        assert finished is None or partial is None
        self._ensure_validity()

    def write_message(self, message_type, message_text, *args, **kwargs):
        """Queue a ``(message_type, formatted_text)`` pair under the inspector's ``'messages'`` key.

        `message_text` is formatted with ``str.format(*args, **kwargs)``.

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        self._inspector.setdefault('messages', []).append((message_type, message_text.format(*args, **kwargs)))

    def write_record(self, record):
        """Buffer a single record (see `_write_record`).

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        self._write_record(record)

    def write_records(self, records):
        """Buffer every record of the iterable `records` (see `_write_record`).

        :raises RuntimeError: if the writer has already been finished.
        """
        self._ensure_validity()
        write_record = self._write_record
        for record in records:
            write_record(record)

    def _clear(self):
        """Empty the buffer and the inspector and reset the per-chunk record count and flushed flag."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self._inspector.clear()
        self._record_count = 0
        self._flushed = False

    def _ensure_validity(self):
        """Raise ``RuntimeError`` if the writer is marked finished."""
        if self._finished is True:
            # A finished writer is expected to hold no pending records or inspector entries.
            assert self._record_count == 0 and len(self._inspector) == 0
            raise RuntimeError('I/O operation on closed record writer')

    def _write_record(self, record):
        """Append one record to the buffer as a CSV row.

        The first record written fixes the field names for all later records and emits the header row. Missing
        or ``None`` fields and empty lists produce ``(None, None)``; a one-item list is treated as its single
        item. Once ``_maxresultrows`` records are buffered, ``flush(partial=True)`` is called.

        :param record: Mapping of field name to value; only its ``keys()`` and ``get()`` are used.
        """
        fieldnames = self._fieldnames

        if fieldnames is None:
            self._fieldnames = fieldnames = list(record.keys())
            header = []
            for fieldname in fieldnames:
                name = str(fieldname)
                header += (name, str('__mv_') + name)
            self._writerow(header)

        get_value = record.get
        values = []

        for fieldname in fieldnames:
            value = get_value(fieldname, None)

            if value is None:
                values += (None, None)
                continue

            value_t = type(value)

            if issubclass(value_t, (list, tuple)):

                if len(value) == 0:
                    values += (None, None)
                    continue

                if len(value) > 1:
                    sv = ''
                    mv = '$'

                    for item in value:

                        if item is None:
                            sv += '\n'
                            mv += '$;$'
                            continue

                        item_t = type(item)

                        if item_t is bytes:
                            if not isinstance(item, str):
                                # Python 3 only (on Python 2 bytes is str): the item must become text before it
                                # can be concatenated below. UTF-8 matches the encoding this method uses for
                                # text on Python 2.
                                item = item.decode('utf-8', errors='backslashreplace')
                        elif item_t is bool:
                            item = str(item.real)
                        elif item_t is six.text_type:
                            pass  # already text; used as-is
                        elif isinstance(item, six.integer_types) or item_t is float or item_t is complex:
                            item = str(item)
                        elif issubclass(item_t, (dict, list, tuple)):
                            item = str(''.join(RecordWriter._iterencode_json(item, 0)))
                        else:
                            # repr() already returns a native str; encoding it produced bytes on Python 3 and
                            # broke the concatenation below.
                            item = repr(item)

                        sv += item + '\n'
                        mv += item.replace('$', '$$') + '$;$'

                    # Strip the trailing newline from sv and the trailing ';$' from mv.
                    values += (sv[:-1], mv[:-2])
                    continue

                value = value[0]
                value_t = type(value)

            if value_t is bool:
                values += (str(value.real), None)
                continue

            if value_t is bytes:
                # NOTE(review): handed to the csv writer unchanged; on Python 3 the csv module stringifies
                # non-string data with str(), so this is written as "b'...'" -- confirm whether that is intended.
                values += (value, None)
                continue

            if value_t is six.text_type:
                if six.PY2:
                    value = value.encode('utf-8')
                values += (value, None)
                continue

            if isinstance(value, six.integer_types) or value_t is float or value_t is complex:
                values += (str(value), None)
                continue

            if issubclass(value_t, dict):
                values += (str(''.join(RecordWriter._iterencode_json(value, 0))), None)
                continue

            values += (repr(value), None)

        self._writerow(values)
        self._record_count += 1

        if self._record_count >= self._maxresultrows:
            self.flush(partial=True)

    # Class-level setup: bind `_iterencode_json`, the compact JSON encoder used by `_write_record`.
    try:
        # noinspection PyUnresolvedReferences
        from _json import make_encoder
    except ImportError:
        # We may be running under PyPy 2.5 which does not include the _json module
        _iterencode_json = JSONEncoder(separators=(',', ':')).iterencode
    else:
        # Creating _iterencode_json this way yields a two-fold performance improvement on Python 2.7.9 and 2.7.10
        from json.encoder import encode_basestring_ascii

        @staticmethod
        def _default(o):
            # Hook handed to the encoder for objects it cannot serialize itself: always reject them.
            raise TypeError(repr(o) + ' is not JSON serializable')

        _iterencode_json = make_encoder(
            {},                       # markers (for detecting circular references)
            _default,                 # object_encoder
            encode_basestring_ascii,  # string_encoder
            None,                     # indent
            ':', ',',                 # separators
            False,                    # sort_keys
            False,                    # skip_keys
            True                      # allow_nan
        )

        # Remove the temporary import so the name does not remain bound as a class attribute.
        del make_encoder