Code example #1 (score: 0)
File: test_writer.py — Project: vitgou/warcio
    def test_request_response_concur(self, is_gzip, builder_factory):
        """Round-trip a request/response pair and verify WARC-Concurrent-To linkage."""
        warc_writer = BufferWARCWriter(gzip=is_gzip)
        record_builder = builder_factory(warc_writer, builder_cls=RecordBuilder)

        response = sample_response(record_builder)
        request = sample_request(record_builder)

        # Explicitly compute digests (block digest enabled) on the response
        # before it is written out.
        warc_writer.ensure_digest(response, block=True, payload=True)

        warc_writer.write_request_response_pair(request, response)

        # Read both records back out of the in-memory stream.
        response, request = list(ArchiveIterator(warc_writer.get_stream()))

        response_id = response.rec_headers.get_header('WARC-Record-ID')
        request_id = request.rec_headers.get_header('WARC-Record-ID')

        # The two records are distinct, and the request points back at the
        # response via its concurrent-to header.
        assert response_id != request_id
        assert response_id == request.rec_headers.get_header('WARC-Concurrent-To')
Code example #2 (score: 0)
File: main.py — Project: machawk1/cdxj-indexer
class CDXJIndexer(Indexer):
    """Index WARC/ARC records into CDXJ lines.

    Extends the base ``Indexer`` with CDXJ field-name mapping, optional
    sorted and/or gzip-compressed output, and request/response pairing
    (needed for POST-append and ``req.http:*`` fields).
    """

    # Canonical field name -> short key used in the CDXJ JSON output.
    field_names = {
        "warc-target-uri": "url",
        "http:status": "status",
        "warc-payload-digest": "digest",
        "req.http:referer": "referrer",
        "req.http:method": "method",
    }

    # Reverse lookup: short CDXJ key -> canonical field name.
    inv_field_names = {k: v for v, k in field_names.items()}

    # Fields emitted when the caller does not replace them.
    DEFAULT_FIELDS = [
        "warc-target-uri",
        "mime",
        "http:status",
        "warc-payload-digest",
        "length",
        "offset",
        "filename",
    ]

    # Record types indexed by default.
    DEFAULT_RECORDS = ["response", "revisit", "resource", "metadata"]

    ALLOWED_EXT = (".arc", ".arc.gz", ".warc", ".warc.gz")

    # Splits a content-type on the first ';' or whitespace.
    RE_SPACE = re.compile(r'[;\s]')

    def __init__(self,
                 output,
                 inputs,
                 post_append=False,
                 sort=False,
                 compress=None,
                 lines=300,
                 data_out_name=None,
                 filename=None,
                 fields=None,
                 replace_fields=None,
                 records=None,
                 verify_http=False,
                 dir_root=None,
                 **kwargs):
        """Create an indexer.

        :param output: output filename or stream (falls back to stdout)
        :param inputs: input path(s) — file, directory, or stream
        :param post_append: append POST/PUT request bodies to the url key
        :param sort: sort the output lines
        :param compress: filename or stream for gzip-compressed output
        :param lines: CDXJ lines per compressed block
        :param data_out_name: recorded name for the compressed data file
        :param filename: forced value for the 'filename' field
        :param fields: comma-separated fields to add to the defaults
        :param replace_fields: comma-separated fields replacing the defaults
        :param records: comma-separated record types to index, or 'all'
        :param verify_http: verify HTTP payloads while parsing
        :param dir_root: root dir for computing relative filenames
        """
        # Accept a single path or stream as well as a list.
        if isinstance(inputs, str) or hasattr(inputs, "read"):
            inputs = [inputs]

        inputs = iter_file_or_dir(inputs)

        fields = self._parse_fields(fields, replace_fields)

        super(CDXJIndexer, self).__init__(fields,
                                          inputs,
                                          output,
                                          verify_http=verify_http)
        # Lazily-created WARC writer, used only to recompute missing digests.
        self.writer = None

        self.curr_filename = None
        self.force_filename = filename
        self.post_append = post_append
        self.dir_root = dir_root

        self.num_lines = lines
        self.sort = sort
        self.compress = compress
        self.data_out_name = data_out_name

        # None means "index every record type".
        self.include_records = records
        if self.include_records == "all":
            self.include_records = None
        elif self.include_records:
            self.include_records = self.include_records.split(",")
        else:
            self.include_records = self.DEFAULT_RECORDS

        # Requests must be buffered and paired with their responses when
        # POST-appending or when any req.http:* field was requested.
        self.collect_records = self.post_append or any(
            field.startswith("req.http:") for field in self.fields)
        self.record_parse = True

    def _parse_fields(self, fields=None, replace_fields=None):
        """Resolve the output field list from the add/replace field specs."""
        add_fields = replace_fields
        if add_fields:
            # Replacement mode: start from an empty list.
            fields = []
        else:
            # Additive mode: start from a copy of the defaults.
            add_fields = fields
            fields = copy(self.DEFAULT_FIELDS)

        if add_fields:
            add_fields = add_fields.split(",")
            for field in add_fields:
                # Accept short CDXJ keys (e.g. 'url') as aliases.
                fields.append(self.inv_field_names.get(field, field))

        return fields

    def get_field(self, record, name, it, filename):
        """Return the value of one index field for ``record``."""
        if name == "mime":
            if record.rec_type == "revisit":
                return "warc/revisit"
            elif record.rec_type in ("response", "request"):
                name = "http:content-type"
            else:
                name = "content-type"

            value = super(CDXJIndexer, self).get_field(record, name, it,
                                                       filename)
            if value:
                # Strip parameters: keep only the bare media type.
                value = self.RE_SPACE.split(value, 1)[0].strip()

            return value

        if name == "filename":
            return self.curr_filename

        if self.collect_records:
            # Offsets/lengths were captured by req_resolving_iter().
            if name == "offset":
                return str(record.file_offset)
            elif name == "length":
                return str(record.file_length)
            elif name.startswith("req.http:"):
                value = self._get_req_field(name, record)
                if value:
                    return value

        value = super(CDXJIndexer, self).get_field(record, name, it, filename)

        if name == "warc-payload-digest":
            value = self._get_digest(record, name)

        return value

    def _get_req_field(self, name, record):
        """Look up a ``req.http:*`` field on the paired request record."""
        if hasattr(record, "req"):
            req = record.req
        elif record.rec_type == "request":
            req = record
        else:
            return None

        if name == "req.http:method":
            # NOTE(review): for request records, http_headers.protocol
            # appears to carry the HTTP method token from the request line.
            return req.http_headers.protocol
        else:
            # Strip the 'req.http:' prefix to get the raw header name.
            return req.http_headers.get_header(name[9:])

    def process_all(self):
        """Process all inputs, wiring up compression/sorting writers."""
        data_out = None

        with open_or_default(self.output, "wt", sys.stdout) as fh:
            if self.compress:
                if isinstance(self.compress, str):
                    # Add the default extension *before* creating the file,
                    # so the recorded data_out_name matches the file that is
                    # actually written (previously the suffix was appended
                    # after open(), leaving the two out of sync).
                    if os.path.splitext(self.compress)[1] == "":
                        self.compress += ".cdxj.gz"
                    data_out = open(self.compress, "wb")

                    fh = CompressedWriter(
                        fh,
                        data_out=data_out,
                        data_out_name=self.compress,
                        num_lines=self.num_lines,
                    )
                else:
                    fh = CompressedWriter(
                        fh,
                        data_out=self.compress,
                        data_out_name=self.data_out_name,
                        num_lines=self.num_lines,
                    )

            if self.sort:
                fh = SortingWriter(fh)

            self.output = fh

            super().process_all()

            if self.sort or self.compress:
                fh.flush()
                if data_out:
                    data_out.close()

    def _resolve_rel_path(self, filename):
        """Return ``filename`` relative to ``dir_root`` (or its basename)."""
        if not self.dir_root:
            return os.path.basename(filename)

        path = os.path.relpath(filename, self.dir_root)
        if os.path.sep != "/":  # pragma: no cover
            path = path.replace(os.path.sep, "/")
        return path

    def process_one(self, input_, output, filename):
        """Index a single input file or stream."""
        self.curr_filename = self.force_filename or self._resolve_rel_path(
            filename)

        it = self._create_record_iter(input_)

        self._write_header(output, filename)

        if self.collect_records:
            wrap_it = self.req_resolving_iter(it)
        else:
            wrap_it = it

        for record in wrap_it:
            if not self.include_records or record.rec_type in self.include_records:
                self.process_index_entry(it, record, filename, output)

    def _get_digest(self, record, name):
        """Return the payload digest for ``record``, computing it if absent."""
        value = record.rec_headers.get(name)
        if not value:
            if not self.writer:
                self.writer = BufferWARCWriter()

            self.writer.ensure_digest(record, block=False, payload=True)
            value = record.rec_headers.get(name)

        if value:
            # Drop the algorithm prefix (e.g. 'sha1:').
            value = value.split(":")[-1]
        return value

    def _write_line(self, out, index, record, filename):
        """Write one CDXJ line: urlkey, 14-digit timestamp, JSON fields."""
        url = index.get("url")
        if not url:
            url = record.rec_headers.get("WARC-Target-URI")

        dt = record.rec_headers.get("WARC-Date")

        ts = iso_date_to_timestamp(dt)

        # A POST-appended record carries its own precomputed urlkey.
        if hasattr(record, "urlkey"):
            urlkey = record.urlkey
        else:
            urlkey = self.get_url_key(url)

        self._do_write(urlkey, ts, index, out)

    def _do_write(self, urlkey, ts, index, out):
        """Serialize one CDXJ line to ``out``."""
        out.write(urlkey + " " + ts + " " + json.dumps(index) + "\n")

    def get_url_key(self, url):
        """Return the canonicalized (SURT) form of ``url``."""
        try:
            return surt.surt(url)
        # Narrowed from a bare except: never swallow SystemExit/KeyboardInterrupt.
        except Exception:  # pragma: no coverage
            return url

    def _concur_req_resp(self, rec_1, rec_2):
        """Return (request, response) if the two records form a concurrent
        pair for the same target URI, else (None, None)."""
        if not rec_1 or not rec_2:
            return None, None

        if rec_1.rec_headers.get_header(
                "WARC-Target-URI") != rec_2.rec_headers.get_header(
                    "WARC-Target-URI"):
            return None, None

        if rec_2.rec_headers.get_header(
                "WARC-Concurrent-To") != rec_1.rec_headers.get_header(
                    "WARC-Record-ID"):
            return None, None

        if rec_1.rec_type == "response" and rec_2.rec_type == "request":
            req = rec_2
            resp = rec_1

        elif rec_1.rec_type == "request" and rec_2.rec_type == "response":
            req = rec_1
            resp = rec_2

        else:
            return None, None

        return req, resp

    def req_resolving_iter(self, record_iter):
        """Yield records, joining each response with its concurrent request.

        Request payloads are buffered so they can still be read after the
        underlying stream has advanced; file offset/length are recorded on
        every record for the 'offset'/'length' fields.
        """
        prev_record = None

        for record in record_iter:
            if record.rec_type == "request":
                record.buffered_stream = BytesIO(
                    record.content_stream().read())

            record.file_offset = record_iter.get_record_offset()
            record.file_length = record_iter.get_record_length()

            req, resp = self._concur_req_resp(prev_record, record)

            if not req or not resp:
                # Not a pair: flush the held record and hold the new one.
                if prev_record:
                    yield prev_record
                prev_record = record
                continue

            self._join_req_resp(req, resp)

            yield prev_record
            yield record
            prev_record = None

        if prev_record:
            yield prev_record

    def _join_req_resp(self, req, resp):
        """Attach the request to its response; POST-append the url key."""
        resp.req = req

        # NOTE(review): http_headers.protocol holds the request method token.
        method = req.http_headers.protocol
        if self.post_append and method.upper() in ("POST", "PUT"):
            post_url = append_post_query(req, resp)
            resp.urlkey = self.get_url_key(post_url)
            req.urlkey = resp.urlkey
Code example #3 (score: 0)
File: main.py — Project: donfanning/cdxj-indexer
class CDXJIndexer(Indexer):
    """Index WARC/ARC records into CDXJ lines.

    Extends the base ``Indexer`` with CDXJ field-name mapping, optional
    sorted and/or gzip-compressed output, request/response pairing, and
    optional whole-record sha256 digests ('record-digest').
    """

    # Canonical field name -> short key used in the CDXJ JSON output.
    field_names = {
        "warc-target-uri": "url",
        "http:status": "status",
        "warc-payload-digest": "digest",
        "req.http:referer": "referrer",
        "req.http:method": "method",
        "record-digest": "recordDigest",
    }

    # Reverse lookup: short CDXJ key -> canonical field name.
    inv_field_names = {k: v for v, k in field_names.items()}

    # Fields emitted when the caller does not replace them.
    DEFAULT_FIELDS = [
        "warc-target-uri",
        "mime",
        "http:status",
        "warc-payload-digest",
        "length",
        "offset",
        "filename",
    ]

    # Record types indexed by default.
    DEFAULT_RECORDS = ["response", "revisit", "resource", "metadata"]

    ALLOWED_EXT = (".arc", ".arc.gz", ".warc", ".warc.gz")

    # Splits a content-type on the first ';' or whitespace.
    RE_SPACE = re.compile(r"[;\s]")

    # Chunk size for streaming digest computation.
    BUFF_SIZE = 1024 * 64

    DEFAULT_NUM_LINES = 300

    def __init__(self,
                 output,
                 inputs,
                 post_append=False,
                 sort=False,
                 compress=None,
                 lines=DEFAULT_NUM_LINES,
                 data_out_name=None,
                 filename=None,
                 fields=None,
                 replace_fields=None,
                 records=None,
                 verify_http=False,
                 dir_root=None,
                 digest_records=False,
                 **kwargs):
        """Create an indexer.

        :param output: output filename or stream (falls back to stdout)
        :param inputs: input path(s) — file, directory, or stream
        :param post_append: append POST/PUT request bodies to the url key
        :param sort: sort the output lines
        :param compress: filename or stream for gzip-compressed output
        :param lines: CDXJ lines per compressed block
        :param data_out_name: recorded name for the compressed data file
        :param filename: forced value for the 'filename' field
        :param fields: comma-separated fields to add to the defaults
        :param replace_fields: comma-separated fields replacing the defaults
        :param records: comma-separated record types to index, or 'all'
        :param verify_http: verify HTTP payloads while parsing
        :param dir_root: root dir for computing relative filenames
        :param digest_records: compute a sha256 digest of each raw record
        """
        # Accept a single path or stream as well as a list.
        if isinstance(inputs, str) or hasattr(inputs, "read"):
            inputs = [inputs]

        inputs = iter_file_or_dir(inputs)

        # Must be set before _parse_fields(), which consults it.
        self.digest_records = digest_records
        fields = self._parse_fields(fields, replace_fields)

        super(CDXJIndexer, self).__init__(fields,
                                          inputs,
                                          output,
                                          verify_http=verify_http)
        # Lazily-created WARC writer, used only to recompute missing digests.
        self.writer = None

        self.curr_filename = None
        self.force_filename = filename
        self.post_append = post_append
        self.dir_root = dir_root

        self.num_lines = lines
        self.sort = sort
        self.compress = compress
        self.data_out_name = data_out_name

        # None means "index every record type".
        self.include_records = records
        if self.include_records == "all":
            self.include_records = None
        elif self.include_records:
            self.include_records = self.include_records.split(",")
        else:
            self.include_records = self.DEFAULT_RECORDS

        # Requests must be buffered and paired with their responses when
        # POST-appending or when any req.http:* field was requested.
        self.collect_records = self.post_append or any(
            field.startswith("req.http:") for field in self.fields)
        self.record_parse = True

    def _parse_fields(self, fields=None, replace_fields=None):
        """Resolve the output field list from the add/replace field specs."""
        add_fields = replace_fields
        if add_fields:
            # Replacement mode: start from an empty list.
            fields = []
        else:
            # Additive mode: start from a copy of the defaults.
            add_fields = fields
            fields = copy(self.DEFAULT_FIELDS)

        if self.digest_records and "record-digest" not in fields:
            fields.append("record-digest")

        if add_fields:
            add_fields = add_fields.split(",")
            for field in add_fields:
                # Accept short CDXJ keys (e.g. 'url') as aliases.
                fields.append(self.inv_field_names.get(field, field))

        return fields

    def get_field(self, record, name, it, filename):
        """Return the value of one index field for ``record``."""
        if name == "mime":
            if record.rec_type == "revisit":
                return "warc/revisit"
            elif record.rec_type in ("response", "request"):
                name = "http:content-type"
            else:
                name = "content-type"

            value = super(CDXJIndexer, self).get_field(record, name, it,
                                                       filename)
            if value:
                # Strip parameters: keep only the bare media type.
                value = self.RE_SPACE.split(value, 1)[0].strip()

            return value

        if name == "filename":
            return self.curr_filename

        if self.collect_records:
            # Offsets/lengths/digests were captured by req_resolving_iter().
            if name == "offset":
                return str(record.file_offset)
            elif name == "length":
                return str(record.file_length)
            elif name == "record-digest":
                return str(record.record_digest)
            elif name.startswith("req.http:"):
                value = self._get_req_field(name, record)
                if value:
                    return value

        value = super(CDXJIndexer, self).get_field(record, name, it, filename)

        if name == "warc-payload-digest":
            value = self._get_digest(record, name)

        return value

    def _get_req_field(self, name, record):
        """Look up a ``req.http:*`` field on the paired request record."""
        if hasattr(record, "req"):
            req = record.req
        elif record.rec_type == "request":
            req = record
        else:
            return None

        if name == "req.http:method":
            # NOTE(review): for request records, http_headers.protocol
            # appears to carry the HTTP method token from the request line.
            return req.http_headers.protocol
        else:
            # Strip the 'req.http:' prefix to get the raw header name.
            return req.http_headers.get_header(name[9:])

    def process_all(self):
        """Process all inputs, wiring up compression/sorting writers."""
        data_out = None

        with open_or_default(self.output, "wt", sys.stdout) as fh:
            if self.compress:
                if isinstance(self.compress, str):
                    # Add the default extension *before* creating the file,
                    # so the recorded data_out_name matches the file that is
                    # actually written (previously the suffix was appended
                    # after open(), leaving the two out of sync).
                    if os.path.splitext(self.compress)[1] == "":
                        self.compress += ".cdxj.gz"
                    data_out = open(self.compress, "wb")

                    fh = CompressedWriter(
                        fh,
                        data_out=data_out,
                        data_out_name=self.compress,
                        num_lines=self.num_lines,
                        digest_records=self.digest_records,
                    )
                else:
                    fh = CompressedWriter(
                        fh,
                        data_out=self.compress,
                        data_out_name=self.data_out_name,
                        num_lines=self.num_lines,
                        digest_records=self.digest_records,
                    )

            if self.sort:
                fh = SortingWriter(fh)

            self.output = fh

            super().process_all()

            if self.sort or self.compress:
                fh.flush()
                if data_out:
                    data_out.close()

    def _resolve_rel_path(self, filename):
        """Return ``filename`` relative to ``dir_root`` (or its basename)."""
        if not self.dir_root:
            return os.path.basename(filename)

        path = os.path.relpath(filename, self.dir_root)
        if os.path.sep != "/":  # pragma: no cover
            path = path.replace(os.path.sep, "/")
        return path

    def process_one(self, input_, output, filename):
        """Index a single input file or stream."""
        self.curr_filename = self.force_filename or self._resolve_rel_path(
            filename)

        it = self._create_record_iter(input_)

        self._write_header(output, filename)

        if self.collect_records:
            # input_ doubles as a raw reader for whole-record digesting.
            wrap_it = self.req_resolving_iter(it, input_)
        else:
            wrap_it = it

        for record in wrap_it:
            if not self.include_records or self.filter_record(record):
                self.process_index_entry(it, record, filename, output)

    def filter_record(self, record):
        """Return True if ``record`` should be indexed."""
        if record.rec_type not in self.include_records:
            return False

        # With the default record set, skip warcinfo-style metadata bodies.
        if (self.include_records == self.DEFAULT_RECORDS
                and record.rec_type in ("resource", "metadata")
                and record.rec_headers.get_header("Content-Type")
                == "application/warc-fields"):
            return False

        return True

    def _get_digest(self, record, name):
        """Return the payload digest for ``record``, computing it if absent."""
        value = record.rec_headers.get(name)
        if not value:
            if not self.writer:
                self.writer = BufferWARCWriter()

            self.writer.ensure_digest(record, block=False, payload=True)
            value = record.rec_headers.get(name)

        return value

    def _write_line(self, out, index, record, filename):
        """Write one CDXJ line: urlkey, 14-digit timestamp, JSON fields."""
        url = index.get("url")
        if not url:
            url = record.rec_headers.get("WARC-Target-URI")

        dt = record.rec_headers.get("WARC-Date")

        ts = iso_date_to_timestamp(dt)

        # A POST-appended record carries its own precomputed urlkey.
        if hasattr(record, "urlkey"):
            urlkey = record.urlkey
        else:
            urlkey = self.get_url_key(url)

        # Extra attributes set by _join_req_resp() for POST/PUT records.
        if hasattr(record, "requestBody"):
            index["requestBody"] = record.requestBody
        if hasattr(record, "method"):
            index["method"] = record.method

        self._do_write(urlkey, ts, index, out)

    def _do_write(self, urlkey, ts, index, out):
        """Serialize one CDXJ line to ``out``."""
        out.write(urlkey + " " + ts + " " + json.dumps(index) + "\n")

    def get_url_key(self, url):
        """Return the canonicalized (SURT) form of ``url``."""
        try:
            return surt.surt(url)
        # Narrowed from a bare except: never swallow SystemExit/KeyboardInterrupt.
        except Exception:  # pragma: no coverage
            return url

    def _concur_req_resp(self, rec_1, rec_2):
        """Return (request, response) if the two records form a concurrent
        pair for the same target URI, else (None, None)."""
        if not rec_1 or not rec_2:
            return None, None

        if rec_1.rec_headers.get_header(
                "WARC-Target-URI") != rec_2.rec_headers.get_header(
                    "WARC-Target-URI"):
            return None, None

        if rec_2.rec_headers.get_header(
                "WARC-Concurrent-To") != rec_1.rec_headers.get_header(
                    "WARC-Record-ID"):
            return None, None

        if rec_1.rec_type == "response" and rec_2.rec_type == "request":
            req = rec_2
            resp = rec_1

        elif rec_1.rec_type == "request" and rec_2.rec_type == "response":
            req = rec_1
            resp = rec_2

        else:
            return None, None

        return req, resp

    def read_content(self, record):
        """Buffer the record payload in a spooled temp file so it can be
        re-read after the underlying stream has advanced."""
        spool = tempfile.SpooledTemporaryFile()
        shutil.copyfileobj(record.content_stream(), spool)
        spool.seek(0)
        record.buffered_stream = spool

    def req_resolving_iter(self, record_iter, digest_reader):
        """Yield records, joining each response with its concurrent request.

        Records file offset/length on every record and, when enabled,
        a sha256 digest of the raw record bytes read via ``digest_reader``.
        """
        prev_record = None

        for record in record_iter:
            self.read_content(record)

            record.file_offset = record_iter.get_record_offset()
            record.file_length = record_iter.get_record_length()

            if digest_reader and self.digest_records:
                # Re-read the raw record bytes, restoring the read position.
                curr = digest_reader.tell()
                digest_reader.seek(record.file_offset)
                record_digest, digest_length = self.digest_block(
                    digest_reader, record.file_length)
                digest_reader.seek(curr)

                if digest_length != record.file_length:
                    # Fixed: the message is now actually formatted, and the
                    # actual digested length is reported (the original
                    # referenced an undefined name 'buff').
                    raise Exception(
                        "Digest block mismatch, expected {0}, got {1}".format(
                            record.file_length, digest_length))

                record.record_digest = record_digest

            req, resp = self._concur_req_resp(prev_record, record)

            if not req or not resp:
                # Not a pair: flush the held record and hold the new one.
                if prev_record:
                    yield prev_record
                    prev_record.buffered_stream.close()
                prev_record = record
                continue

            self._join_req_resp(req, resp)

            yield prev_record
            prev_record.buffered_stream.close()
            yield record
            record.buffered_stream.close()
            prev_record = None

        if prev_record:
            yield prev_record
            prev_record.buffered_stream.close()

    def _join_req_resp(self, req, resp):
        """Attach the request to its response; POST-append the url key."""
        resp.req = req

        # NOTE(review): http_headers.protocol holds the request method token.
        method = req.http_headers.protocol
        if self.post_append and method.upper() in ("POST", "PUT"):
            url = req.rec_headers.get_header("WARC-Target-URI")
            query, append_str = append_method_query_from_req_resp(req, resp)
            resp.method = method.upper()
            resp.requestBody = query
            resp.urlkey = self.get_url_key(url + append_str)
            req.urlkey = resp.urlkey

    def digest_block(self, reader, length):
        """Stream ``length`` bytes from ``reader`` through sha256.

        :returns: ('sha256:<hexdigest>', bytes actually read)
        """
        count = 0
        hasher = hashlib.sha256()

        while length > 0:
            buff = reader.read(min(self.BUFF_SIZE, length))
            if not buff:
                break
            hasher.update(buff)
            length -= len(buff)
            count += len(buff)

        return "sha256:" + hasher.hexdigest(), count
Code example #4 (score: 0)
class CDXJIndexer(Indexer):
    """Index WARC/ARC records into CDXJ lines.

    Extends the base ``Indexer`` with CDXJ field-name mapping, optional
    sorted and/or gzip-compressed output, optional whole-record digests,
    and request/response pairing via the external ``buffering_record_iter``
    helper.
    """

    # Canonical field name -> short key used in the CDXJ JSON output.
    field_names = {
        "warc-target-uri": "url",
        "http:status": "status",
        "warc-payload-digest": "digest",
        "req.http:referer": "referrer",
        "req.http:method": "method",
        "record-digest": "recordDigest",
    }

    # Reverse lookup: short CDXJ key -> canonical field name.
    inv_field_names = {k: v for v, k in field_names.items()}

    # Fields emitted when the caller does not replace them.
    DEFAULT_FIELDS = [
        "warc-target-uri",
        "mime",
        "http:status",
        "warc-payload-digest",
        "length",
        "offset",
        "filename",
    ]

    # Record types indexed by default.
    DEFAULT_RECORDS = ["response", "revisit", "resource", "metadata"]

    ALLOWED_EXT = (".arc", ".arc.gz", ".warc", ".warc.gz")

    # Splits a content-type on the first ';' or whitespace.
    RE_SPACE = re.compile(r"[;\s]")

    DEFAULT_NUM_LINES = 300

    def __init__(self,
                 output,
                 inputs,
                 post_append=False,
                 sort=False,
                 compress=None,
                 lines=DEFAULT_NUM_LINES,
                 max_sort_buff_size=None,
                 data_out_name=None,
                 filename=None,
                 fields=None,
                 replace_fields=None,
                 records=None,
                 verify_http=False,
                 dir_root=None,
                 digest_records=False,
                 **kwargs):
        """Create an indexer.

        :param output: output filename or stream (falls back to stdout)
        :param inputs: input path(s) — file, directory, or stream
        :param post_append: append POST/PUT request bodies to the url key
        :param sort: sort the output lines
        :param compress: filename or stream for gzip-compressed output
        :param lines: CDXJ lines per compressed block
        :param max_sort_buff_size: max in-memory buffer for sorting
        :param data_out_name: recorded name for the compressed data file
        :param filename: forced value for the 'filename' field
        :param fields: comma-separated fields to add to the defaults
        :param replace_fields: comma-separated fields replacing the defaults
        :param records: comma-separated record types to index, or 'all'
        :param verify_http: verify HTTP payloads while parsing
        :param dir_root: root dir for computing relative filenames
        :param digest_records: compute a digest of each raw record
        """
        # Accept a single path or stream as well as a list.
        if isinstance(inputs, str) or hasattr(inputs, "read"):
            inputs = [inputs]

        inputs = iter_file_or_dir(inputs)

        # Must be set before _parse_fields(), which consults it.
        self.digest_records = digest_records
        fields = self._parse_fields(fields, replace_fields)

        super(CDXJIndexer, self).__init__(fields,
                                          inputs,
                                          output,
                                          verify_http=verify_http)
        # Lazily-created WARC writer, used only to recompute missing digests.
        self.writer = None

        self.curr_filename = None
        self.force_filename = filename
        self.post_append = post_append
        self.dir_root = dir_root

        self.num_lines = lines
        self.max_sort_buff_size = max_sort_buff_size
        self.sort = sort
        self.compress = compress
        self.data_out_name = data_out_name

        # None means "index every record type".
        self.include_records = records
        if self.include_records == "all":
            self.include_records = None
        elif self.include_records:
            self.include_records = self.include_records.split(",")
        else:
            self.include_records = self.DEFAULT_RECORDS

        # Requests must be buffered and paired with their responses when
        # POST-appending or when any req.http:* field was requested.
        self.collect_records = self.post_append or any(
            field.startswith("req.http:") for field in self.fields)
        self.record_parse = True

    def _parse_fields(self, fields=None, replace_fields=None):
        """Resolve the output field list from the add/replace field specs."""
        add_fields = replace_fields
        if add_fields:
            # Replacement mode: start from an empty list.
            fields = []
        else:
            # Additive mode: start from a copy of the defaults.
            add_fields = fields
            fields = copy(self.DEFAULT_FIELDS)

        if self.digest_records and "record-digest" not in fields:
            fields.append("record-digest")

        if add_fields:
            add_fields = add_fields.split(",")
            for field in add_fields:
                # Accept short CDXJ keys (e.g. 'url') as aliases.
                fields.append(self.inv_field_names.get(field, field))

        return fields

    def get_field(self, record, name, it, filename):
        """Return the value of one index field for ``record``."""
        if name == "mime":
            if record.rec_type == "revisit":
                return "warc/revisit"
            elif record.rec_type in ("response", "request"):
                name = "http:content-type"
            else:
                name = "content-type"

            value = super(CDXJIndexer, self).get_field(record, name, it,
                                                       filename)
            if value:
                # Strip parameters: keep only the bare media type.
                value = self.RE_SPACE.split(value, 1)[0].strip()

            return value

        if name == "filename":
            return self.curr_filename

        if self.collect_records:
            # Offsets/lengths/digests were captured by buffering_record_iter().
            if name == "offset":
                return str(record.file_offset)
            elif name == "length":
                return str(record.file_length)
            elif name == "record-digest":
                return str(record.record_digest)
            elif name.startswith("req.http:"):
                value = self._get_req_field(name, record)
                if value:
                    return value

        value = super(CDXJIndexer, self).get_field(record, name, it, filename)

        if name == "warc-payload-digest":
            value = self._get_digest(record, name)

        return value

    def _get_req_field(self, name, record):
        """Look up a ``req.http:*`` field on the paired request record."""
        if hasattr(record, "req"):
            req = record.req
        elif record.rec_type == "request":
            req = record
        else:
            return None

        if name == "req.http:method":
            # NOTE(review): for request records, http_headers.protocol
            # appears to carry the HTTP method token from the request line.
            return req.http_headers.protocol
        else:
            # Strip the 'req.http:' prefix to get the raw header name.
            return req.http_headers.get_header(name[9:])

    def process_all(self):
        """Process all inputs, wiring up compression/sorting writers."""
        data_out = None

        with open_or_default(self.output, "wt", sys.stdout) as fh:
            if self.compress:
                if isinstance(self.compress, str):
                    # Add the default extension *before* creating the file,
                    # so the recorded data_out_name matches the file that is
                    # actually written (previously the suffix was appended
                    # after open(), leaving the two out of sync).
                    if os.path.splitext(self.compress)[1] == "":
                        self.compress += ".cdxj.gz"
                    data_out = open(self.compress, "wb")

                    fh = CompressedWriter(
                        fh,
                        data_out=data_out,
                        data_out_name=self.compress,
                        num_lines=self.num_lines,
                        digest_records=self.digest_records,
                    )
                else:
                    fh = CompressedWriter(
                        fh,
                        data_out=self.compress,
                        data_out_name=self.data_out_name,
                        num_lines=self.num_lines,
                        digest_records=self.digest_records,
                    )

            if self.sort:
                fh = SortingWriter(fh, self.max_sort_buff_size)

            self.output = fh

            super().process_all()

            if self.sort or self.compress:
                fh.flush()
                if data_out:
                    data_out.close()

    def _resolve_rel_path(self, filename):
        """Return ``filename`` relative to ``dir_root`` (or its basename)."""
        if not self.dir_root:
            return os.path.basename(filename)

        path = os.path.relpath(filename, self.dir_root)
        if os.path.sep != "/":  # pragma: no cover
            path = path.replace(os.path.sep, "/")
        return path

    def process_one(self, input_, output, filename):
        """Index a single input file or stream."""
        self.curr_filename = self.force_filename or self._resolve_rel_path(
            filename)

        it = self._create_record_iter(input_)

        self._write_header(output, filename)

        if self.collect_records:
            # input_ doubles as a raw reader for whole-record digesting.
            digest_reader = input_ if self.digest_records else None
            wrap_it = buffering_record_iter(
                it,
                post_append=self.post_append,
                digest_reader=digest_reader,
                url_key_func=self.get_url_key,
            )
        else:
            wrap_it = it

        for record in wrap_it:
            if not self.include_records or self.filter_record(record):
                self.process_index_entry(it, record, filename, output)

    def filter_record(self, record):
        """Return True if ``record`` should be indexed."""
        if record.rec_type not in self.include_records:
            return False

        # With the default record set, skip warcinfo-style metadata bodies.
        if (self.include_records == self.DEFAULT_RECORDS
                and record.rec_type in ("resource", "metadata")
                and record.rec_headers.get_header("Content-Type")
                == "application/warc-fields"):
            return False

        return True

    def _get_digest(self, record, name):
        """Return the payload digest for ``record``, computing it if absent."""
        value = record.rec_headers.get(name)
        if not value:
            if not self.writer:
                self.writer = BufferWARCWriter()

            self.writer.ensure_digest(record, block=False, payload=True)
            value = record.rec_headers.get(name)

        return value

    def _write_line(self, out, index, record, filename):
        """Write one CDXJ line: urlkey, 14-digit timestamp, JSON fields."""
        url = index.get("url")
        if not url:
            url = record.rec_headers.get("WARC-Target-URI")

        dt = record.rec_headers.get("WARC-Date")

        ts = iso_date_to_timestamp(dt)

        # A POST-appended record carries its own precomputed urlkey.
        if hasattr(record, "urlkey"):
            urlkey = record.urlkey
        else:
            urlkey = self.get_url_key(url)

        # Extra attributes set by buffering_record_iter() for POST/PUT records.
        if hasattr(record, "requestBody"):
            index["requestBody"] = record.requestBody
        if hasattr(record, "method"):
            index["method"] = record.method

        self._do_write(urlkey, ts, index, out)

    def _do_write(self, urlkey, ts, index, out):
        """Serialize one CDXJ line to ``out``."""
        out.write(urlkey + " " + ts + " " + json.dumps(index) + "\n")

    def get_url_key(self, url):
        """Return the canonicalized (SURT) form of ``url``."""
        try:
            return surt.surt(url)
        # Narrowed from a bare except: never swallow SystemExit/KeyboardInterrupt.
        except Exception:  # pragma: no coverage
            return url
Code example #5 (score: 0)
class CDXJIndexer(Indexer):
    """CDXJ indexer (older variant): maps WARC/HTTP header fields to CDXJ keys."""

    # Maps WARC/HTTP header names to their short CDXJ output names.
    field_names = {
        'warc-target-uri': 'url',
        'http:status': 'status',
        'warc-payload-digest': 'digest',
        'req.http:referer': 'referrer',
        'req.http:method': 'method',
    }

    # Reverse lookup: short CDXJ output name -> header name.
    inv_field_names = {k: v for v, k in field_names.items()}

    # Fields included in the index unless overridden via options.
    DEFAULT_FIELDS = [
        'warc-target-uri', 'mime', 'http:status', 'warc-payload-digest',
        'length', 'offset', 'filename'
    ]

    # Record types indexed by default.
    DEFAULT_RECORDS = ['response', 'revisit', 'resource', 'metadata']

    def __init__(self, output, inputs, opts=None):
        """Configure the indexer from an options dict.

        Recognized options: ``filename``, ``post_append``, ``records``
        (``'all'``, a comma-separated list, or unset for the defaults),
        plus the field options handled by ``_parse_fields``.
        """
        if not opts:
            opts = {}

        super(CDXJIndexer, self).__init__(self._parse_fields(opts),
                                          inputs, output)

        # Created lazily, only if a digest ever needs computing.
        self.writer = None

        self.curr_filename = None
        self.force_filename = opts.get('filename')
        self.post_append = opts.get('post_append')

        records_opt = opts.get('records')
        if records_opt == 'all':
            # None means "index every record type".
            self.write_records = None
        elif not records_opt:
            self.write_records = self.DEFAULT_RECORDS
        else:
            self.write_records = records_opt.split(',')

        # Requests must be buffered/paired when POST-append is on or any
        # requested field comes from the paired request record.
        wants_req_fields = any(
            f.startswith('req.http:') for f in self.fields)
        self.collect_records = self.post_append or wants_req_fields
        self.record_parse = True

    def _parse_fields(self, opts):
        """Build the field list from 'fields'/'replace_fields' options.

        'replace_fields' discards the defaults entirely; 'fields' extends
        them.  Short CDXJ names are translated back to header names.
        """
        extra = opts.get('replace_fields')
        if extra:
            fields = []
        else:
            fields = copy(self.DEFAULT_FIELDS)
            extra = opts.get('fields')

        if extra:
            fields.extend(self.inv_field_names.get(name, name)
                          for name in extra.split(','))

        return fields

    def get_field(self, record, name, it, filename):
        """Resolve one index field value for ``record``.

        Handles the synthetic 'mime' and 'filename' fields, per-record
        offset/length when records are buffered, and 'req.http:*' fields
        taken from the paired request; everything else is delegated to
        the base indexer.
        """
        if name == 'mime':
            if record.rec_type == 'revisit':
                return 'warc/revisit'
            if record.rec_type in ('response', 'request'):
                name = 'http:content-type'
            else:
                name = 'content-type'

            mime = super(CDXJIndexer, self).get_field(record, name, it,
                                                      filename)
            # Strip any parameter suffix, e.g. 'text/html; charset=...'.
            return mime.split(';')[0].strip() if mime else mime

        if name == 'filename':
            return self.curr_filename

        if self.collect_records:
            # Offsets/lengths were stamped on the record by
            # req_resolving_iter rather than read from the iterator.
            if name == 'offset':
                return str(record.file_offset)
            if name == 'length':
                return str(record.file_length)
            if name.startswith('req.http:'):
                req_value = self._get_req_field(name, record)
                if req_value:
                    return req_value

        value = super(CDXJIndexer, self).get_field(record, name, it, filename)

        if name == 'warc-payload-digest':
            value = self._get_digest(record, name)

        return value

    def _get_req_field(self, name, record):
        """Fetch a 'req.http:*' field from the request paired with ``record``.

        Returns None when no request record is associated.
        """
        if hasattr(record, 'req'):
            request = record.req
        elif record.rec_type == 'request':
            request = record
        else:
            return None

        if name == 'req.http:method':
            # warcio keeps the request method in the protocol slot.
            return request.http_headers.protocol

        # Strip the 'req.http:' prefix (9 chars) to get the header name.
        return request.http_headers.get_header(name[9:])

    def process_one(self, input_, output, filename):
        """Index a single WARC/ARC input stream into ``output``."""
        self.curr_filename = self.force_filename or os.path.basename(filename)

        record_iter = self._create_record_iter(input_)

        self._write_header(output, filename)

        # Only buffer/pair request records when some field needs them.
        records = (self.req_resolving_iter(record_iter)
                   if self.collect_records else record_iter)

        for record in records:
            if not self.write_records or record.rec_type in self.write_records:
                self.process_index_entry(record_iter, record, filename, output)

    def _get_digest(self, record, name):
        """Return the digest header ``name``, computing it when missing.

        The algorithm prefix (e.g. 'sha1:') is stripped from the result.
        """
        digest = record.rec_headers.get(name)
        if not digest:
            # Create the helper writer on first use only.
            if not self.writer:
                self.writer = BufferWARCWriter()

            self.writer.ensure_digest(record, block=False, payload=True)
            digest = record.rec_headers.get(name)

        # 'sha1:ABC...' -> 'ABC...'
        return digest.split(':')[-1] if digest else digest

    def _write_line(self, out, index, record, filename):
        """Write one CDXJ line for ``record`` to ``out``."""
        url = index.get('url') or record.rec_headers.get('WARC-Target-URI')

        ts = iso_date_to_timestamp(record.rec_headers.get('WARC-Date'))

        # A urlkey stamped on the record (e.g. by POST-append) wins over
        # recomputing it from the URL.
        if hasattr(record, 'urlkey'):
            urlkey = record.urlkey
        else:
            urlkey = self.get_url_key(url)

        self._do_write(urlkey, ts, index, out)

    def _do_write(self, urlkey, ts, index, out):
        """Emit a single '<urlkey> <timestamp> <json>' line to ``out``."""
        out.write('{0} {1} '.format(urlkey, ts))
        out.write(json.dumps(index) + '\n')

    def get_url_key(self, url):
        """Return the canonicalized SURT key for ``url``.

        Falls back to the raw URL when SURT canonicalization fails.
        """
        try:
            return surt.surt(url)
        # Catch Exception rather than a bare "except:", which would also
        # swallow KeyboardInterrupt/SystemExit.
        except Exception:  #pragma: no coverage
            return url

    def _concur_req_resp(self, rec_1, rec_2):
        """Pair two adjacent records as (request, response) if they match.

        Records match when both exist, target the same URI, and the
        second record's WARC-Concurrent-To points at the first record's
        WARC-Record-ID.  Returns (None, None) when no pairing applies.
        """
        if not rec_1 or not rec_2:
            return None, None

        uri_1 = rec_1.rec_headers.get_header('WARC-Target-URI')
        uri_2 = rec_2.rec_headers.get_header('WARC-Target-URI')
        if uri_1 != uri_2:
            return None, None

        concur = rec_2.rec_headers.get_header('WARC-Concurrent-To')
        if concur != rec_1.rec_headers.get_header('WARC-Record-ID'):
            return None, None

        type_pair = (rec_1.rec_type, rec_2.rec_type)
        if type_pair == ('response', 'request'):
            return rec_2, rec_1
        if type_pair == ('request', 'response'):
            return rec_1, rec_2

        return None, None

    def req_resolving_iter(self, record_iter):
        """Yield records, joining adjacent request/response pairs.

        Each request's body is buffered so it can be re-read later, and
        every record is stamped with its file offset/length.  When two
        consecutive records form a concurrent request/response pair they
        are joined (see _join_req_resp) before being yielded.
        """
        pending = None

        for record in record_iter:
            # Buffer the request payload for later re-reading.
            if record.rec_type == 'request':
                body = record.content_stream().read()
                record.buffered_stream = BytesIO(body)

            record.file_offset = record_iter.get_record_offset()
            record.file_length = record_iter.get_record_length()

            req, resp = self._concur_req_resp(pending, record)

            if req and resp:
                self._join_req_resp(req, resp)
                yield pending
                yield record
                pending = None
            else:
                # No pair: flush the held record and hold this one.
                if pending:
                    yield pending
                pending = record

        if pending:
            yield pending

    def _join_req_resp(self, req, resp):
        """Attach ``req`` to ``resp`` and set POST/PUT-specific urlkeys."""
        resp.req = req

        # warcio keeps the request method in http_headers.protocol.
        method = req.http_headers.protocol
        if self.post_append and method.upper() in ('POST', 'PUT'):
            # Fold the request body into the URL so distinct POSTs to the
            # same URI index under distinct keys.
            post_url = append_post_query(req, resp)
            if post_url:
                resp.urlkey = self.get_url_key(post_url)
                req.urlkey = resp.urlkey