예제 #1
0
 def __init__(
         self,
         loop: asyncio.BaseEventLoop,
         config: 'Config',
         blob_manager: 'BlobManager',
         sd_hash: str,
         download_directory: typing.Optional[str] = None,
         file_name: typing.Optional[str] = None,
         status: typing.Optional[str] = STATUS_STOPPED,
         claim: typing.Optional[StoredStreamClaim] = None,
         download_id: typing.Optional[str] = None,
         rowid: typing.Optional[int] = None,
         descriptor: typing.Optional[StreamDescriptor] = None,
         content_fee: typing.Optional['Transaction'] = None,
         analytics_manager: typing.Optional['AnalyticsManager'] = None):
     """Set up a managed stream for ``sd_hash`` and create its ``StreamDownloader``.

     ``download_id`` falls back to a freshly generated hex id when it is not supplied (or is empty).
     ``descriptor`` is passed straight through to the downloader.
     """
     # collaborators
     self.loop, self.config, self.blob_manager = loop, config, blob_manager
     self.sd_hash = sd_hash

     # output location and persisted state
     self.download_directory, self._file_name = download_directory, file_name
     self._status = status
     self.stream_claim_info = claim
     self.download_id = download_id if download_id else binascii.hexlify(generate_id()).decode()
     self.rowid = rowid
     self.written_bytes = 0
     self.content_fee = content_fee

     self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor)
     self.analytics_manager = analytics_manager

     # background work handles
     self.file_output_task: typing.Optional[asyncio.Task] = None
     self.delayed_stop: typing.Optional[asyncio.Handle] = None

     # progress / state events, all bound to this stream's loop
     self.fully_reflected = asyncio.Event(loop=self.loop)
     self.saving = asyncio.Event(loop=self.loop)
     self.finished_writing = asyncio.Event(loop=self.loop)
     self.started_writing = asyncio.Event(loop=self.loop)
예제 #2
0
 async def start_downloader(
         self,
         got_descriptor_time: asyncio.Future,
         downloader: StreamDownloader,
         download_id: str,
         outpoint: str,
         claim: Claim,
         resolved: typing.Dict,
         file_name: typing.Optional[str] = None) -> ManagedStream:
     """Drive ``downloader`` until it has written data and return the resulting managed stream.

     ``got_descriptor_time`` is resolved with the loop time spent waiting for the stream descriptor.
     The stream and its content claim are persisted, and the stream is only added to ``self.streams``
     after the downloader's ``wrote_bytes_event`` fires.
     """
     began = self.loop.time()
     downloader.download(self.node)
     await downloader.got_descriptor.wait()
     got_descriptor_time.set_result(self.loop.time() - began)

     # persist the stream and link it to the claim outpoint
     rowid = await self._store_stream(downloader)
     await self.storage.save_content_claim(downloader.descriptor.stream_hash, outpoint)

     stream = ManagedStream(
         self.loop, self.blob_manager, rowid, downloader.descriptor, self.config.download_dir,
         file_name, downloader, ManagedStream.STATUS_RUNNING, download_id=download_id
     )
     stream.set_claim(resolved, claim)

     # the stream is only tracked once data has actually been written
     await stream.downloader.wrote_bytes_event.wait()
     self.streams.add(stream)
     return stream
예제 #3
0
 async def setup_stream(self, blob_count: int = 10):
     """Create a random test stream on the server side and a client-side downloader for it.

     Writes ``blob_count`` chunks of ``MAX_BLOB_SIZE - 1`` random bytes to ``<server_dir>/test_file``,
     turns that file into a stream in the server blob directory, and stores the raw bytes
     (``self.stream_bytes``), the sd hash (``self.sd_hash``) and a ``StreamDownloader``
     (``self.downloader``) on the test case.
     """
     # build the payload with a single join; repeated ``bytes +=`` re-copies the whole buffer each time
     self.stream_bytes = b''.join(os.urandom(MAX_BLOB_SIZE - 1) for _ in range(blob_count))
     # create the stream
     file_path = os.path.join(self.server_dir, "test_file")
     with open(file_path, 'wb') as f:
         f.write(self.stream_bytes)
     descriptor = await StreamDescriptor.create_stream(
         self.loop, self.server_blob_manager.blob_dir, file_path)
     self.sd_hash = descriptor.calculate_sd_hash()
     # NOTE(review): the two ``3`` arguments sit where other call sites pass peer_timeout and
     # peer_connect_timeout — confirm against the StreamDownloader signature in use
     self.downloader = StreamDownloader(self.loop, self.client_blob_manager,
                                        self.sd_hash, 3, 3, self.client_dir)
예제 #4
0
    async def _download_stream_from_claim(self, node: 'Node', download_directory: str, claim_info: typing.Dict,
                                          file_name: typing.Optional[str] = None) -> typing.Optional[ManagedStream]:
        """Download the stream referenced by a resolved claim and start tracking it.

        :raises DownloadSDTimeout: if fetching the stream descriptor times out or is cancelled
        :raises DownloadDataTimeout: if waiting for the first written bytes is cancelled; the stream is
            stopped and removed via ``stop_stream`` first
        """
        claim = smart_decode(claim_info['value'])
        sd_hash = claim.source_hash.decode()
        downloader = StreamDownloader(
            self.loop, self.config, self.blob_manager, sd_hash, download_directory, file_name
        )

        # phase 1: fetch the stream descriptor
        try:
            downloader.download(node)
            await downloader.got_descriptor.wait()
            log.info("got descriptor %s for %s", sd_hash, claim_info['name'])
        except (asyncio.TimeoutError, asyncio.CancelledError):
            log.info("stream timeout")
            downloader.stop()
            log.info("stopped stream")
            raise DownloadSDTimeout(downloader.sd_hash)

        # phase 2: persist the stream and its claim, then start tracking it
        rowid = await self._store_stream(downloader)
        await self.storage.save_content_claim(
            downloader.descriptor.stream_hash, f"{claim_info['txid']}:{claim_info['nout']}"
        )
        stream = ManagedStream(
            self.loop, self.blob_manager, rowid, downloader.descriptor, download_directory,
            file_name, downloader, ManagedStream.STATUS_RUNNING
        )
        stream.set_claim(claim_info, claim)
        self.streams.add(stream)

        # phase 3: wait for the first data to be written
        try:
            await stream.downloader.wrote_bytes_event.wait()
            self.wait_for_stream_finished(stream)
            return stream
        except asyncio.CancelledError:
            downloader.stop()
            log.debug("stopped stream")
        # only reached when the wait above was cancelled
        await self.stop_stream(stream)
        raise DownloadDataTimeout(downloader.sd_hash)
예제 #5
0
    async def _download_stream_from_claim(self, node: 'Node', download_directory: str, claim_info: typing.Dict,
                                          file_name: typing.Optional[str] = None) -> typing.Optional[ManagedStream]:
        """Download the stream referenced by a resolved claim and start tracking it.

        Returns the new ``ManagedStream``, or ``None`` if fetching the descriptor times out or is
        cancelled, or if waiting for the first written bytes is cancelled (the downloader is stopped
        in both cases).
        """
        claim = ClaimDict.load_dict(claim_info['value'])
        sd_hash = claim.source_hash.decode()
        downloader = StreamDownloader(
            self.loop, self.blob_manager, sd_hash, self.peer_timeout, self.peer_connect_timeout,
            download_directory, file_name, self.fixed_peers
        )

        try:
            downloader.download(node)
            await downloader.got_descriptor.wait()
            log.info("got descriptor %s for %s", sd_hash, claim_info['name'])
        except (asyncio.TimeoutError, asyncio.CancelledError):
            log.info("stream timeout")
            await downloader.stop()
            log.info("stopped stream")
            return None

        # record the stream and output file only if they are not already stored
        storage = self.blob_manager.storage
        if not await storage.stream_exists(downloader.sd_hash):
            await storage.store_stream(downloader.sd_blob, downloader.descriptor)
        if not await storage.file_exists(downloader.sd_hash):
            await storage.save_downloaded_file(
                downloader.descriptor.stream_hash, os.path.basename(downloader.output_path),
                download_directory, 0.0
            )

        stream_hash = downloader.descriptor.stream_hash
        outpoint = f"{claim_info['txid']}:{claim_info['nout']}"
        await storage.save_content_claim(stream_hash, outpoint)

        stored_claim = StoredStreamClaim(
            stream_hash, outpoint, claim_info['claim_id'], claim_info['name'], claim_info['amount'],
            claim_info['height'], claim_info['hex'], claim.certificate_id, claim_info['address'],
            claim_info['claim_sequence'], claim_info.get('channel_name')
        )
        stream = ManagedStream(
            self.loop, self.blob_manager, downloader.descriptor, download_directory,
            os.path.basename(downloader.output_path), downloader, ManagedStream.STATUS_RUNNING,
            stored_claim
        )
        self.streams.add(stream)

        try:
            await stream.downloader.wrote_bytes_event.wait()
            self.wait_for_stream_finished(stream)
            return stream
        except asyncio.CancelledError:
            await downloader.stop()
            return None
예제 #6
0
 async def load_streams_from_database(self):
     """Recreate a ``ManagedStream`` (with its downloader) for every stored file whose sd blob is verified.

     The directory and file name columns hold hex strings; they are unhexlified and decoded here.
     Files whose sd blob is not verified are skipped.
     """
     for file_info in await self.storage.get_all_lbry_files():
         sd_blob = self.blob_manager.get_blob(file_info['sd_hash'])
         if not sd_blob.get_is_verified():
             continue
         descriptor = await self.blob_manager.get_stream_descriptor(sd_blob.blob_hash)
         download_directory = binascii.unhexlify(file_info['download_directory']).decode()
         file_name = binascii.unhexlify(file_info['file_name']).decode()
         downloader = StreamDownloader(
             self.loop, self.blob_manager, descriptor.sd_hash, self.peer_timeout,
             self.peer_connect_timeout, download_directory, file_name, self.fixed_peers
         )
         self.streams.add(ManagedStream(
             self.loop, self.blob_manager, descriptor, download_directory, file_name,
             downloader, file_info['status'], file_info['claim']
         ))
예제 #7
0
class ManagedStream:
    """A single LBRY stream tracked by the stream manager.

    Wraps a ``StreamDownloader`` and tracks:

    - the stream's status (running / stopped / finished);
    - where, and whether, its decrypted output is written on disk;
    - its claim metadata;
    - any in-flight file-saving task or HTTP streaming responses.

    Status and file-location changes are persisted through ``blob_manager.storage``.
    """
    STATUS_RUNNING = "running"
    STATUS_STOPPED = "stopped"
    STATUS_FINISHED = "finished"

    # instances have no __dict__, so every attribute assigned in __init__ must be listed here
    __slots__ = [
        'loop', 'config', 'blob_manager', 'sd_hash', 'download_directory',
        '_file_name', '_status', 'stream_claim_info', 'download_id', 'rowid',
        'written_bytes', 'content_fee', 'downloader', 'analytics_manager',
        'fully_reflected', 'file_output_task', 'delayed_stop_task',
        'streaming_responses', 'streaming', '_running', 'saving',
        'finished_writing', 'started_writing', 'finished_write_attempt'
    ]

    def __init__(
            self,
            loop: asyncio.BaseEventLoop,
            config: 'Config',
            blob_manager: 'BlobManager',
            sd_hash: str,
            download_directory: typing.Optional[str] = None,
            file_name: typing.Optional[str] = None,
            status: typing.Optional[str] = STATUS_STOPPED,
            claim: typing.Optional[StoredStreamClaim] = None,
            download_id: typing.Optional[str] = None,
            rowid: typing.Optional[int] = None,
            descriptor: typing.Optional[StreamDescriptor] = None,
            content_fee: typing.Optional['Transaction'] = None,
            analytics_manager: typing.Optional['AnalyticsManager'] = None):
        """
        :param loop: event loop used for the stream's tasks and events
        :param config: daemon config; download dir/timeout and streaming host/port are read from it
        :param blob_manager: blob manager; its ``storage`` persists stream and file state
        :param sd_hash: hash of the stream descriptor blob
        :param download_directory: directory the decrypted file is (or will be) saved in
        :param file_name: output file name; the ``file_name`` property falls back to the descriptor's
            suggested name
        :param status: initial status, one of the ``STATUS_*`` constants
        :param claim: stored claim information for the stream, if known
        :param download_id: id sent to analytics; a random hex id is generated when not given
        :param rowid: database row id of the file entry
        :param descriptor: already-known stream descriptor, handed to the ``StreamDownloader``
        :param content_fee: transaction that paid the content fee, if any
        :param analytics_manager: optional analytics sink notified when a file save finishes
        """
        self.loop = loop
        self.config = config
        self.blob_manager = blob_manager
        self.sd_hash = sd_hash
        self.download_directory = download_directory
        self._file_name = file_name
        self._status = status
        self.stream_claim_info = claim
        self.download_id = download_id or binascii.hexlify(
            generate_id()).decode()
        self.rowid = rowid
        self.written_bytes = 0
        self.content_fee = content_fee
        self.downloader = StreamDownloader(self.loop, self.config,
                                           self.blob_manager, sd_hash,
                                           descriptor)
        self.analytics_manager = analytics_manager

        self.fully_reflected = asyncio.Event(loop=self.loop)
        self.file_output_task: typing.Optional[asyncio.Task] = None
        self.delayed_stop_task: typing.Optional[asyncio.Task] = None
        # (request, response) pairs currently being served by stream_file()
        self.streaming_responses: typing.List[typing.Tuple[
            Request, StreamResponse]] = []
        self.streaming = asyncio.Event(loop=self.loop)
        self._running = asyncio.Event(loop=self.loop)
        self.saving = asyncio.Event(loop=self.loop)
        self.finished_writing = asyncio.Event(loop=self.loop)
        self.started_writing = asyncio.Event(loop=self.loop)
        self.finished_write_attempt = asyncio.Event(loop=self.loop)

    @property
    def descriptor(self) -> StreamDescriptor:
        """The stream descriptor held by the downloader (``file_name`` guards against it being unset)."""
        return self.downloader.descriptor

    @property
    def stream_hash(self) -> str:
        return self.descriptor.stream_hash

    @property
    def file_name(self) -> typing.Optional[str]:
        """Explicit output file name, else the descriptor's suggested name, else ``None``."""
        return self._file_name or (self.descriptor.suggested_file_name
                                   if self.descriptor else None)

    @property
    def status(self) -> str:
        return self._status

    async def update_status(self, status: str):
        """Set the status (one of the ``STATUS_*`` constants) and persist it to storage."""
        assert status in [
            self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED
        ]
        self._status = status
        await self.blob_manager.storage.change_file_status(
            self.stream_hash, status)

    @property
    def finished(self) -> bool:
        return self.status == self.STATUS_FINISHED

    @property
    def running(self) -> bool:
        return self.status == self.STATUS_RUNNING

    # --- claim information: every property below is None when no claim info is attached ---

    @property
    def claim_id(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.claim_id

    @property
    def txid(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.txid

    @property
    def nout(self) -> typing.Optional[int]:
        return None if not self.stream_claim_info else self.stream_claim_info.nout

    @property
    def outpoint(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.outpoint

    @property
    def claim_height(self) -> typing.Optional[int]:
        return None if not self.stream_claim_info else self.stream_claim_info.height

    @property
    def channel_claim_id(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.channel_claim_id

    @property
    def channel_name(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.channel_name

    @property
    def claim_name(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.claim_name

    @property
    def metadata(self) -> typing.Optional[typing.Dict]:
        return None if not self.stream_claim_info else self.stream_claim_info.claim.stream.to_dict()

    @property
    def metadata_protobuf(self) -> typing.Optional[bytes]:
        """Hex-encoded (as bytes) serialized claim, or ``None`` when no claim info is attached."""
        if self.stream_claim_info:
            return binascii.hexlify(self.stream_claim_info.claim.to_bytes())
        return None

    # --- blob progress: the last descriptor entry is excluded from all counts ---

    @property
    def blobs_completed(self) -> int:
        """Number of the stream's blobs (excluding the last descriptor entry) that are verified locally."""
        return sum(
            1 for b in self.descriptor.blobs[:-1]
            if self.blob_manager.is_blob_verified(b.blob_hash)
        )

    @property
    def blobs_in_stream(self) -> int:
        return len(self.descriptor.blobs) - 1

    @property
    def blobs_remaining(self) -> int:
        return self.blobs_in_stream - self.blobs_completed

    @property
    def full_path(self) -> typing.Optional[str]:
        """Output path (download directory + base name of ``file_name``), or ``None`` if either is unset."""
        if self.file_name and self.download_directory:
            return os.path.join(self.download_directory, os.path.basename(self.file_name))
        return None

    @property
    def output_file_exists(self):
        return os.path.isfile(self.full_path) if self.full_path else False

    @property
    def mime_type(self):
        """Media type guessed from the descriptor's suggested file name."""
        return guess_media_type(
            os.path.basename(self.descriptor.suggested_file_name))[0]

    def as_dict(self) -> typing.Dict:
        """Return a JSON-friendly summary of the stream.

        ``file_name``, ``download_directory``, ``download_path`` and ``written_bytes`` are ``None``
        unless the output file currently exists. When it exists, ``written_bytes`` falls back to the
        on-disk size if nothing has been counted in this session.
        """
        full_path = self.full_path
        file_name = self.file_name
        download_directory = self.download_directory
        if self.full_path and self.output_file_exists:
            written_bytes = self.written_bytes or os.stat(self.full_path).st_size
        else:
            full_path = None
            file_name = None
            download_directory = None
            written_bytes = None
        return {
            'streaming_url': f"http://{self.config.streaming_host}:{self.config.streaming_port}/stream/{self.sd_hash}",
            'completed': (self.output_file_exists and self.status in ('stopped', 'finished')) or all(
                self.blob_manager.is_blob_verified(b.blob_hash) for b in self.descriptor.blobs[:-1]
            ),
            'file_name': file_name,
            'download_directory': download_directory,
            'points_paid': 0.0,
            'stopped': not self.running,
            'stream_hash': self.stream_hash,
            'stream_name': self.descriptor.stream_name,
            'suggested_file_name': self.descriptor.suggested_file_name,
            'sd_hash': self.descriptor.sd_hash,
            'download_path': full_path,
            'mime_type': self.mime_type,
            'key': self.descriptor.key,
            'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(),
            'total_bytes': self.descriptor.upper_bound_decrypted_length(),
            'written_bytes': written_bytes,
            'blobs_completed': self.blobs_completed,
            'blobs_in_stream': self.blobs_in_stream,
            'blobs_remaining': self.blobs_remaining,
            'status': self.status,
            'claim_id': self.claim_id,
            'txid': self.txid,
            'nout': self.nout,
            'outpoint': self.outpoint,
            'metadata': self.metadata,
            'protobuf': self.metadata_protobuf,
            'channel_claim_id': self.channel_claim_id,
            'channel_name': self.channel_name,
            'claim_name': self.claim_name,
            'content_fee': self.content_fee
        }

    @classmethod
    async def create(
        cls,
        loop: asyncio.BaseEventLoop,
        config: 'Config',
        blob_manager: 'BlobManager',
        file_path: str,
        key: typing.Optional[bytes] = None,
        iv_generator: typing.Optional[typing.Generator[bytes, None,
                                                       None]] = None
    ) -> 'ManagedStream':
        """Create a stream from a local file, store it, and record it as a published file.

        The returned stream points at the original file and has status ``STATUS_FINISHED``.
        """
        descriptor = await StreamDescriptor.create_stream(
            loop,
            blob_manager.blob_dir,
            file_path,
            key=key,
            iv_generator=iv_generator,
            blob_completed_callback=blob_manager.blob_completed)
        await blob_manager.storage.store_stream(
            blob_manager.get_blob(descriptor.sd_hash), descriptor)
        row_id = await blob_manager.storage.save_published_file(
            descriptor.stream_hash, os.path.basename(file_path),
            os.path.dirname(file_path), 0)
        return cls(loop,
                   config,
                   blob_manager,
                   descriptor.sd_hash,
                   os.path.dirname(file_path),
                   os.path.basename(file_path),
                   status=cls.STATUS_FINISHED,
                   rowid=row_id,
                   descriptor=descriptor)

    async def start(self,
                    node: typing.Optional['Node'] = None,
                    timeout: typing.Optional[float] = None,
                    save_now: bool = False):
        """Start the downloader unless it is already running.

        After the downloader starts, this method:

        - (re)schedules the inactivity watchdog ``_delayed_stop``;
        - creates a file row in storage if none exists (name and directory are only recorded when
          ``save_now`` is true);
        - marks the stream as running.

        :param timeout: seconds allowed for ``downloader.start``; defaults to ``config.download_timeout``
        :raises DownloadSDTimeout: if the downloader does not start within ``timeout``
        """
        timeout = timeout or self.config.download_timeout
        if self._running.is_set():
            return
        log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
        self._running.set()
        try:
            await asyncio.wait_for(self.downloader.start(node),
                                   timeout,
                                   loop=self.loop)
        except asyncio.TimeoutError:
            self._running.clear()
            raise DownloadSDTimeout(self.sd_hash)

        if self.delayed_stop_task and not self.delayed_stop_task.done():
            self.delayed_stop_task.cancel()
        self.delayed_stop_task = self.loop.create_task(self._delayed_stop())
        if not await self.blob_manager.storage.file_exists(self.sd_hash):
            if save_now:
                file_name, download_dir = self._file_name, self.download_directory
            else:
                file_name, download_dir = None, None
            self.rowid = await self.blob_manager.storage.save_downloaded_file(
                self.stream_hash, file_name, download_dir, 0.0)
        if self.status != self.STATUS_RUNNING:
            await self.update_status(self.STATUS_RUNNING)

    async def stop(self, finished: bool = False):
        """
        Stop any running save/stream tasks as well as the downloader and update the status in the database
        """
        self.stop_tasks()
        if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
            await self.update_status(
                self.STATUS_FINISHED if finished else self.STATUS_STOPPED)

    async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
            -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
        """Yield ``(blob_info, decrypted_bytes)`` for each content blob from ``start_blob_num`` onward.

        :raises IndexError: if ``start_blob_num`` is past the last content blob
        """
        if start_blob_num >= len(self.descriptor.blobs[:-1]):
            raise IndexError(start_blob_num)
        for i, blob_info in enumerate(
                self.descriptor.blobs[start_blob_num:-1]):
            assert i + start_blob_num == blob_info.blob_num
            decrypted = await self.downloader.read_blob(
                blob_info, connection_id)
            yield (blob_info, decrypted)

    async def stream_file(
            self,
            request: Request,
            node: typing.Optional['Node'] = None) -> StreamResponse:
        """Serve the stream to an HTTP client as a 206 partial-content response.

        Output starts at the blob containing the requested range start. While in flight, the response
        is registered in ``streaming_responses`` and ``streaming`` is set, so ``_delayed_stop`` treats
        the stream as active. The response is always force-closed when this returns or raises.
        """
        log.info("stream file to browser for lbry://%s#%s (sd hash %s...)",
                 self.claim_name, self.claim_id, self.sd_hash[:6])
        await self.start(node)
        headers, size, skip_blobs = self._prepare_range_response_headers(
            request.headers.get('range', 'bytes=0-'))
        response = StreamResponse(status=206, headers=headers)
        await response.prepare(request)
        self.streaming_responses.append((request, response))
        self.streaming.set()
        try:
            wrote = 0
            async for blob_info, decrypted in self._aiter_read_stream(
                    skip_blobs, connection_id=2):
                if (blob_info.blob_num == len(self.descriptor.blobs) -
                        2) or (len(decrypted) + wrote >= size):
                    # final chunk: zero-pad up to the advertised size (a negative count adds nothing)
                    decrypted += (b'\x00' * (size - len(decrypted) - wrote -
                                             (skip_blobs * 2097151)))
                    await response.write_eof(decrypted)
                else:
                    await response.write(decrypted)
                wrote += len(decrypted)
                log.info("sent browser %sblob %i/%i",
                         "(final) " if response._eof_sent else "",
                         blob_info.blob_num + 1,
                         len(self.descriptor.blobs) - 1)
                if response._eof_sent:
                    break
            return response
        finally:
            response.force_close()
            if (request, response) in self.streaming_responses:
                self.streaming_responses.remove((request, response))
            if not self.streaming_responses:
                self.streaming.clear()

    @staticmethod
    def _write_decrypted_blob(handle: typing.IO, data: bytes):
        """Write and flush one decrypted chunk (run in an executor by ``_save_file``)."""
        handle.write(data)
        handle.flush()

    async def _save_file(self, output_path: str):
        """Write every decrypted blob to ``output_path``; run as ``file_output_task``.

        On success, the stream is marked finished, analytics are notified (if configured) and the file
        is flagged as saved. On any error, the partial file is removed. A ``TimeoutError`` additionally:

        - stops the downloader;
        - forgets the output location;
        - marks the stream stopped;

        and then returns normally. Other errors are re-raised; unexpected ones are logged first.
        ``finished_write_attempt`` is always set when this ends.
        """
        log.info("save file for lbry://%s#%s (sd hash %s...) -> %s",
                 self.claim_name, self.claim_id, self.sd_hash[:6], output_path)
        self.saving.set()
        self.finished_write_attempt.clear()
        self.finished_writing.clear()
        self.started_writing.clear()
        try:
            with open(output_path, 'wb') as file_write_handle:
                async for blob_info, decrypted in self._aiter_read_stream(
                        connection_id=1):
                    log.info("write blob %i/%i", blob_info.blob_num + 1,
                             len(self.descriptor.blobs) - 1)
                    # file I/O happens off the event loop
                    await self.loop.run_in_executor(None,
                                                    self._write_decrypted_blob,
                                                    file_write_handle,
                                                    decrypted)
                    self.written_bytes += len(decrypted)
                    if not self.started_writing.is_set():
                        self.started_writing.set()
            await self.update_status(ManagedStream.STATUS_FINISHED)
            if self.analytics_manager:
                self.loop.create_task(
                    self.analytics_manager.send_download_finished(
                        self.download_id, self.claim_name, self.sd_hash))
            self.finished_writing.set()
            log.info(
                "finished saving file for lbry://%s#%s (sd hash %s...) -> %s",
                self.claim_name, self.claim_id, self.sd_hash[:6],
                self.full_path)
            await self.blob_manager.storage.set_saved_file(self.stream_hash)
        except Exception as err:
            if os.path.isfile(output_path):
                log.warning("removing incomplete download %s for %s",
                            output_path, self.sd_hash)
                os.remove(output_path)
            self.written_bytes = 0
            if isinstance(err, asyncio.TimeoutError):
                self.downloader.stop()
                await self.blob_manager.storage.change_file_download_dir_and_file_name(
                    self.stream_hash, None, None)
                self._file_name, self.download_directory = None, None
                await self.blob_manager.storage.clear_saved_file(
                    self.stream_hash)
                await self.update_status(self.STATUS_STOPPED)
                return
            elif not isinstance(err, asyncio.CancelledError):
                log.exception(
                    "unexpected error encountered writing file for stream %s",
                    self.sd_hash)
            raise err
        finally:
            self.saving.clear()
            self.finished_write_attempt.set()

    async def save_file(self,
                        file_name: typing.Optional[str] = None,
                        download_directory: typing.Optional[str] = None,
                        node: typing.Optional['Node'] = None):
        """Start the stream if needed and begin writing it to disk in a background task.

        Any save already in progress is cancelled first. Returns once the new save task has written
        its first blob, or once that task ends without writing anything (e.g. it failed or timed out).

        :param file_name: preferred output name; falls back to the current file name, then to the
            descriptor's suggested name (a free name is then chosen in the directory)
        :param download_directory: output directory; falls back to the current directory, then to
            ``config.download_dir`` (created if missing)
        :param node: passed through to :meth:`start`
        :raises ValueError: if no output directory or file name can be determined
        """
        await self.start(node)
        if self.file_output_task and not self.file_output_task.done():
            # cancel an already running save task
            self.file_output_task.cancel()
        self.download_directory = download_directory or self.download_directory or self.config.download_dir
        if not self.download_directory:
            raise ValueError("no directory to download to")
        # the same fallback chain is used to validate that a name exists and to choose it
        requested_name = file_name or self._file_name or self.descriptor.suggested_file_name
        if not requested_name:
            raise ValueError("no file name to download to")
        if not os.path.isdir(self.download_directory):
            log.warning(
                "download directory '%s' does not exist, attempting to make it",
                self.download_directory)
            os.mkdir(self.download_directory)
        self._file_name = await get_next_available_file_name(
            self.loop, self.download_directory, requested_name)
        await self.blob_manager.storage.change_file_download_dir_and_file_name(
            self.stream_hash, self.download_directory, self.file_name)
        await self.update_status(ManagedStream.STATUS_RUNNING)
        self.written_bytes = 0
        # a previous (cancelled) save may have left this set; clear it so the wait below cannot
        # return before the new task has written anything
        self.started_writing.clear()
        save_task = self.loop.create_task(self._save_file(self.full_path))
        self.file_output_task = save_task
        # wait for the first written blob, but don't hang forever if the save task ends without one
        started = self.loop.create_task(self.started_writing.wait())
        try:
            await asyncio.wait([started, save_task], return_when=asyncio.FIRST_COMPLETED)
        finally:
            started.cancel()

    def stop_tasks(self):
        """Cancel the save task, close all streaming responses and stop the downloader."""
        if self.file_output_task and not self.file_output_task.done():
            self.file_output_task.cancel()
        self.file_output_task = None
        while self.streaming_responses:
            req, response = self.streaming_responses.pop()
            response.force_close()
            # aiohttp reports no transport once the connection is gone; don't let that abort cleanup
            if req.transport:
                req.transport.close()
        self.downloader.stop()
        self._running.clear()

    async def upload_to_reflector(self, host: str,
                                  port: int) -> typing.List[str]:
        """Send the sd blob and any blobs the reflector at ``host:port`` is missing.

        Returns the hashes that were sent. Timeouts, ``ValueError`` and refused connections end the
        upload early and return what was sent so far. Otherwise ``fully_reflected`` is set and the
        reflection is recorded in storage (once).
        """
        sent = []
        protocol = StreamReflectorClient(self.blob_manager, self.descriptor)
        try:
            await self.loop.create_connection(lambda: protocol, host, port)
            await protocol.send_handshake()
            sent_sd, needed = await protocol.send_descriptor()
            if sent_sd:
                sent.append(self.sd_hash)
            if not sent_sd and not needed:
                if not self.fully_reflected.is_set():
                    self.fully_reflected.set()
                    await self.blob_manager.storage.update_reflected_stream(
                        self.sd_hash, f"{host}:{port}")
                    return []
            we_have = [
                blob_hash for blob_hash in needed
                if blob_hash in self.blob_manager.completed_blob_hashes
            ]
            for blob_hash in we_have:
                await protocol.send_blob(blob_hash)
                sent.append(blob_hash)
        except (asyncio.TimeoutError, ValueError, ConnectionRefusedError):
            return sent
        finally:
            if protocol.transport:
                protocol.transport.close()
        if not self.fully_reflected.is_set():
            self.fully_reflected.set()
            await self.blob_manager.storage.update_reflected_stream(
                self.sd_hash, f"{host}:{port}")
        return sent

    def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
        """Replace ``stream_claim_info`` with data from a resolved claim dict and its decoded claim."""
        self.stream_claim_info = StoredStreamClaim(
            self.stream_hash, f"{claim_info['txid']}:{claim_info['nout']}",
            claim_info['claim_id'], claim_info['name'], claim_info['amount'],
            claim_info['height'],
            binascii.hexlify(claim.to_bytes()).decode(),
            claim.signing_channel_id, claim_info['address'],
            claim_info['claim_sequence'], claim_info.get('channel_name'))

    async def update_content_claim(self,
                                   claim_info: typing.Optional[
                                       typing.Dict] = None):
        """Refresh claim info from ``claim_info``, or from storage when it is not given.

        NOTE(review): assumes the stored claim dict's ``'value'`` is a decoded claim object, since
        ``set_claim`` calls ``to_bytes()`` on it — confirm against the storage layer.
        """
        if not claim_info:
            claim_info = await self.blob_manager.storage.get_content_claim(
                self.stream_hash)
        self.set_claim(claim_info, claim_info['value'])

    async def _delayed_stop(self):
        """Watchdog: stop the stream once it is idle on two consecutive checks, about one second apart.

        Idle means neither saving nor streaming to an HTTP client.
        """
        stalled_count = 0
        while self._running.is_set():
            if self.saving.is_set() or self.streaming.is_set():
                stalled_count = 0
            else:
                stalled_count += 1
            if stalled_count > 1:
                log.info("stopping inactive download for lbry://%s#%s (%s...)",
                         self.claim_name, self.claim_id, self.sd_hash[:6])
                await self.stop()
                return
            await asyncio.sleep(1, loop=self.loop)

    def _prepare_range_response_headers(
            self,
            get_range: str) -> typing.Tuple[typing.Dict[str, str], int, int]:
        """Build 206 response headers for a ``Range`` value such as ``bytes=start-end``.

        Returns ``(headers, total_size, skip_blobs)``. The response start is moved back to the
        beginning of the blob that contains the requested start.

        NOTE(review): the blob divisor (2097150) and multiplier (2097151) differ by one; presumably
        deliberate — confirm against the blob size constants.

        :raises ValueError: if the claim's declared size is implausible for the stream's blobs, or
            the range is not numeric
        """
        if '=' in get_range:
            get_range = get_range.split('=')[1]
        start, end = get_range.split('-')
        size = 0

        for blob in self.descriptor.blobs[:-1]:
            size += blob.length - 1
        if self.stream_claim_info and self.stream_claim_info.claim.stream.source.size:
            size_from_claim = int(
                self.stream_claim_info.claim.stream.source.size)
            if not size_from_claim <= size <= size_from_claim + 16:
                raise ValueError("claim contains implausible stream size")
            log.debug("using stream size from claim")
            size = size_from_claim
        elif self.stream_claim_info:
            log.debug("estimating stream size")

        start = int(start)
        end = int(end) if end else size - 1
        skip_blobs = start // 2097150
        skip = skip_blobs * 2097151
        start = skip
        final_size = end - start + 1

        headers = {
            'Accept-Ranges': 'bytes',
            'Content-Range': f'bytes {start}-{end}/{size}',
            'Content-Length': str(final_size),
            'Content-Type': self.mime_type
        }
        return headers, size, skip_blobs
예제 #8
0
    async def _download_stream_from_uri(
            self,
            uri,
            timeout: float,
            exchange_rate_manager: 'ExchangeRateManager',
            file_name: typing.Optional[str] = None) -> ManagedStream:
        """Resolve ``uri``, check its fee, and start downloading the stream.

        If ``_check_update_or_replace`` returns an existing stream for the
        resolved claim, that stream is returned directly. Otherwise a new
        downloader is started and ``start_downloader`` must return a stream
        within ``timeout`` seconds. When the download starts or times out,
        timing data is sent to ``analytics_manager`` (if configured).

        :param uri: lbry URI to resolve; channel URIs are rejected.
        :param timeout: seconds allowed for ``start_downloader`` to return.
        :param exchange_rate_manager: converts the claim fee and the configured
            ``max_key_fee`` into LBC.
        :param file_name: optional output file name passed to the downloader.
        :returns: the started (or updated) ``ManagedStream``.
        :raises ResolveError: channel URI, empty resolve result, or a resolve error.
        :raises KeyFeeAboveMaxAllowed: fee (in LBC) exceeds the configured maximum.
        :raises InsufficientFundsError: fee exceeds the default account balance.
        :raises DownloadSDTimeout: timed out before the descriptor was obtained.
        :raises DownloadDataTimeout: timed out after the descriptor was obtained.
        """
        start_time = self.loop.time()
        parsed_uri = parse_lbry_uri(uri)
        if parsed_uri.is_channel:
            raise ResolveError(
                "cannot download a channel claim, specify a /path")

        # resolve the claim
        resolved = (await self.wallet.ledger.resolve(0, 10, uri)).get(uri, {})
        # use the result directly when it carries 'value', otherwise look under 'claim'
        resolved = resolved if 'value' in resolved else resolved.get('claim')

        if not resolved:
            raise ResolveError(f"Failed to resolve stream at '{uri}'")
        if 'error' in resolved:
            raise ResolveError(f"error resolving stream: {resolved['error']}")

        claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf']))
        outpoint = f"{resolved['txid']}:{resolved['nout']}"
        resolved_time = self.loop.time() - start_time

        # resume or update an existing stream, if the stream changed download it and delete the old one after
        # to_replace, when set, is deleted below once the new download has started
        updated_stream, to_replace = await self._check_update_or_replace(
            outpoint, resolved['claim_id'], claim)
        if updated_stream:
            return updated_stream

        # check that the fee is payable
        fee_amount, fee_address = None, None
        if claim.stream.has_fee:
            # both amounts are converted to LBC and rounded to 5 decimals before comparing
            fee_amount = round(
                exchange_rate_manager.convert_currency(
                    claim.stream.fee.currency, "LBC", claim.stream.fee.amount),
                5)
            max_fee_amount = round(
                exchange_rate_manager.convert_currency(
                    self.config.max_key_fee['currency'], "LBC",
                    Decimal(self.config.max_key_fee['amount'])), 5)
            if fee_amount > max_fee_amount:
                msg = f"fee of {fee_amount} exceeds max configured to allow of {max_fee_amount}"
                log.warning(msg)
                raise KeyFeeAboveMaxAllowed(msg)
            balance = await self.wallet.default_account.get_balance()
            if lbc_to_dewies(str(fee_amount)) > balance:
                msg = f"fee of {fee_amount} exceeds max available balance"
                log.warning(msg)
                raise InsufficientFundsError(msg)
            fee_address = claim.stream.fee.address

        # download the stream
        download_id = binascii.hexlify(generate_id()).decode()
        downloader = StreamDownloader(self.loop, self.config,
                                      self.blob_manager,
                                      claim.stream.source.sd_hash,
                                      self.config.download_dir, file_name)

        stream = None
        # start_downloader resolves this future with the seconds it took to get the descriptor
        descriptor_time_fut = self.loop.create_future()
        start_download_time = self.loop.time()
        time_to_descriptor = None
        time_to_first_bytes = None
        error = None
        try:
            stream = await asyncio.wait_for(
                asyncio.ensure_future(
                    self.start_downloader(descriptor_time_fut, downloader,
                                          download_id, outpoint, claim,
                                          resolved, file_name)), timeout)
            time_to_descriptor = await descriptor_time_fut
            time_to_first_bytes = self.loop.time(
            ) - start_download_time - time_to_descriptor
            self.wait_for_stream_finished(stream)
            # NOTE(review): no fee is sent when replacing an existing stream --
            # presumably it was already paid for the old one; confirm.
            if fee_address and fee_amount and not to_replace:
                stream.tx = await self.wallet.send_amount_to_address(
                    lbc_to_dewies(str(fee_amount)),
                    fee_address.encode('latin1'))
            elif to_replace:  # delete old stream now that the replacement has started downloading
                await self.delete_stream(to_replace)
        # NOTE(review): only TimeoutError is handled; any other exception raised
        # above propagates without stopping the stream/downloader and without
        # sending analytics.
        except asyncio.TimeoutError:
            if descriptor_time_fut.done():
                # the descriptor arrived but the data did not: drop the sd blob
                # and the stored stream for this descriptor
                time_to_descriptor = descriptor_time_fut.result()
                error = DownloadDataTimeout(downloader.sd_hash)
                self.blob_manager.delete_blob(downloader.sd_hash)
                await self.storage.delete_stream(downloader.descriptor)
            else:
                descriptor_time_fut.cancel()
                error = DownloadSDTimeout(downloader.sd_hash)
            if stream:
                await self.stop_stream(stream)
            else:
                downloader.stop()
        if error:
            log.warning(error)
        if self.analytics_manager:
            # NOTE(review): `downloader` is always set at this point, so the
            # `not downloader` branches below are never taken; parse_lbry_uri(uri)
            # repeats the parse already stored in `parsed_uri`.
            self.loop.create_task(
                self.analytics_manager.send_time_to_first_bytes(
                    resolved_time,
                    self.loop.time() - start_time, download_id,
                    parse_lbry_uri(uri).name, outpoint, None if not stream else
                    len(stream.downloader.blob_downloader.active_connections),
                    None if not stream else len(
                        stream.downloader.blob_downloader.scores),
                    False if not downloader else downloader.added_fixed_peers,
                    self.config.fixed_peer_delay
                    if not downloader else downloader.fixed_peers_delay,
                    claim.stream.source.sd_hash, time_to_descriptor,
                    None if not (stream and stream.descriptor) else
                    stream.descriptor.blobs[0].blob_hash,
                    None if not (stream and stream.descriptor) else
                    stream.descriptor.blobs[0].length, time_to_first_bytes,
                    None if not error else error.__class__.__name__))
        # the timeout error is raised only after analytics have been scheduled
        if error:
            raise error
        return stream
예제 #9
0
 def make_downloader(self, sd_hash: str, download_directory: str,
                     file_name: str):
     """Build a StreamDownloader for ``sd_hash`` that writes to the given location.

     The downloader shares this object's loop, config and blob manager.
     """
     loop, config, blob_manager = self.loop, self.config, self.blob_manager
     return StreamDownloader(loop, config, blob_manager, sd_hash,
                             download_directory, file_name)
예제 #10
0
class TestStreamDownloader(BlobExchangeTestBase):
    """Transfer tests for StreamDownloader against the local test blob server."""

    async def setup_stream(self, blob_count: int = 10):
        """Create a ``blob_count``-blob stream on the server and a downloader for it."""
        self.stream_bytes = b''
        for _ in range(blob_count):
            self.stream_bytes += os.urandom((MAX_BLOB_SIZE - 1))
        # create the stream
        file_path = os.path.join(self.server_dir, "test_file")
        with open(file_path, 'wb') as f:
            f.write(self.stream_bytes)
        descriptor = await StreamDescriptor.create_stream(
            self.loop, self.server_blob_manager.blob_dir, file_path)
        self.sd_hash = descriptor.calculate_sd_hash()
        conf = Config(data_dir=self.server_dir,
                      wallet_dir=self.server_dir,
                      download_dir=self.server_dir,
                      reflector_servers=[])
        self.downloader = StreamDownloader(self.loop, conf,
                                           self.client_blob_manager,
                                           self.sd_hash)

    async def _test_transfer_stream(self,
                                    blob_count: int,
                                    mock_accumulate_peers=None):
        """Download a ``blob_count``-blob stream and check the output file.

        ``mock_accumulate_peers`` replaces the node's peer accumulator. By
        default the local server is offered as the only peer.
        """
        await self.setup_stream(blob_count)
        mock_node = mock.Mock(spec=Node)

        def _mock_accumulate_peers(q1, q2):
            async def _task():
                pass

            q2.put_nowait([self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers

        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        self.downloader.stop()
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
        await asyncio.sleep(0.01)
        self.assertTrue(self.downloader.stream_handle.closed)

    async def test_transfer_stream(self):
        await self._test_transfer_stream(10)

    # `@unittest.SkipTest` (an exception class) turned the method into a
    # non-callable exception instance, so the test was silently never
    # collected; `unittest.skip` reports it as skipped instead.
    @unittest.skip("100 blob transfer test is disabled")
    async def test_transfer_hundred_blob_stream(self):
        await self._test_transfer_stream(100)

    async def test_transfer_stream_bad_first_peer_good_second(self):
        """An unreachable first peer must not prevent downloading from a later good one."""
        await self.setup_stream(2)

        mock_node = mock.Mock(spec=Node)

        bad_peer = KademliaPeer(self.loop,
                                "127.0.0.1",
                                b'2' * 48,
                                tcp_port=3334)

        def _mock_accumulate_peers(q1, q2):
            async def _task():
                pass

            # offer the bad peer now and the real server one second later
            q2.put_nowait([bad_peer])
            self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = _mock_accumulate_peers

        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        # stop the downloader, as _test_transfer_stream does, so it does not
        # keep running after the test
        self.downloader.stop()
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
예제 #11
0
class ManagedStream:
    """A stream tracked by the stream manager, optionally written out to a file.

    Wraps a ``StreamDownloader`` for ``sd_hash``. It tracks the output file
    name, download directory, status and claim info, and persists changes to
    them through ``blob_manager.storage``.
    """
    STATUS_RUNNING = "running"
    STATUS_STOPPED = "stopped"
    STATUS_FINISHED = "finished"

    def __init__(
            self,
            loop: asyncio.BaseEventLoop,
            config: 'Config',
            blob_manager: 'BlobManager',
            sd_hash: str,
            download_directory: typing.Optional[str] = None,
            file_name: typing.Optional[str] = None,
            status: typing.Optional[str] = STATUS_STOPPED,
            claim: typing.Optional[StoredStreamClaim] = None,
            download_id: typing.Optional[str] = None,
            rowid: typing.Optional[int] = None,
            descriptor: typing.Optional[StreamDescriptor] = None,
            content_fee: typing.Optional['Transaction'] = None,
            analytics_manager: typing.Optional['AnalyticsManager'] = None):
        self.loop = loop
        self.config = config
        self.blob_manager = blob_manager
        self.sd_hash = sd_hash
        self.download_directory = download_directory
        self._file_name = file_name
        self._status = status
        self.stream_claim_info = claim
        self.download_id = download_id or binascii.hexlify(
            generate_id()).decode()
        self.rowid = rowid
        self.written_bytes = 0
        self.content_fee = content_fee
        self.downloader = StreamDownloader(self.loop, self.config,
                                           self.blob_manager, sd_hash,
                                           descriptor)
        self.analytics_manager = analytics_manager
        self.fully_reflected = asyncio.Event(loop=self.loop)
        self.file_output_task: typing.Optional[asyncio.Task] = None
        self.delayed_stop: typing.Optional[asyncio.Handle] = None
        # set while _save_file is running
        self.saving = asyncio.Event(loop=self.loop)
        # set once _save_file has written every blob
        self.finished_writing = asyncio.Event(loop=self.loop)
        # set once _save_file has written the first blob
        self.started_writing = asyncio.Event(loop=self.loop)

    @property
    def descriptor(self) -> StreamDescriptor:
        return self.downloader.descriptor

    @property
    def stream_hash(self) -> str:
        return self.descriptor.stream_hash

    @property
    def file_name(self) -> typing.Optional[str]:
        """The explicit file name, else the descriptor's suggested name, else None."""
        return self._file_name or (self.descriptor.suggested_file_name
                                   if self.descriptor else None)

    @property
    def status(self) -> str:
        return self._status

    def update_status(self, status: str):
        assert status in [
            self.STATUS_RUNNING, self.STATUS_STOPPED, self.STATUS_FINISHED
        ]
        self._status = status

    @property
    def finished(self) -> bool:
        return self.status == self.STATUS_FINISHED

    @property
    def running(self) -> bool:
        return self.status == self.STATUS_RUNNING

    @property
    def claim_id(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.claim_id

    @property
    def txid(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.txid

    @property
    def nout(self) -> typing.Optional[int]:
        return None if not self.stream_claim_info else self.stream_claim_info.nout

    @property
    def outpoint(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.outpoint

    @property
    def claim_height(self) -> typing.Optional[int]:
        return None if not self.stream_claim_info else self.stream_claim_info.height

    @property
    def channel_claim_id(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.channel_claim_id

    @property
    def channel_name(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.channel_name

    @property
    def claim_name(self) -> typing.Optional[str]:
        return None if not self.stream_claim_info else self.stream_claim_info.claim_name

    @property
    def metadata(self) -> typing.Optional[typing.Dict]:
        if not self.stream_claim_info:
            return None
        return self.stream_claim_info.claim.stream.to_dict()

    @property
    def metadata_protobuf(self) -> typing.Optional[bytes]:
        """Hex-encoded claim bytes, or None when no claim is attached."""
        if self.stream_claim_info:
            return binascii.hexlify(self.stream_claim_info.claim.to_bytes())
        return None

    @property
    def blobs_completed(self) -> int:
        """Number of content blobs (all but the last descriptor entry) that are verified."""
        return sum(1 for b in self.descriptor.blobs[:-1]
                   if self.blob_manager.is_blob_verified(b.blob_hash))

    @property
    def blobs_in_stream(self) -> int:
        return len(self.descriptor.blobs) - 1

    @property
    def blobs_remaining(self) -> int:
        return self.blobs_in_stream - self.blobs_completed

    @property
    def full_path(self) -> typing.Optional[str]:
        """Output path, or None until both a file name and a directory are known."""
        return os.path.join(self.download_directory, os.path.basename(self.file_name)) \
            if self.file_name and self.download_directory else None

    @property
    def output_file_exists(self):
        return os.path.isfile(self.full_path) if self.full_path else False

    @property
    def mime_type(self):
        return guess_media_type(
            os.path.basename(self.descriptor.suggested_file_name))[0]

    def as_dict(self) -> typing.Dict:
        """Summarize the stream's file, descriptor, progress and claim state."""
        if self.written_bytes:
            written_bytes = self.written_bytes
        elif self.output_file_exists:
            written_bytes = os.stat(self.full_path).st_size
        else:
            written_bytes = None
        return {
            'completed': self.finished,
            'file_name': self.file_name,
            'download_directory': self.download_directory,
            'points_paid': 0.0,
            'stopped': not self.running,
            'stream_hash': self.stream_hash,
            'stream_name': self.descriptor.stream_name,
            'suggested_file_name': self.descriptor.suggested_file_name,
            'sd_hash': self.descriptor.sd_hash,
            'download_path': self.full_path,
            'mime_type': self.mime_type,
            'key': self.descriptor.key,
            'total_bytes_lower_bound': self.descriptor.lower_bound_decrypted_length(),
            'total_bytes': self.descriptor.upper_bound_decrypted_length(),
            'written_bytes': written_bytes,
            'blobs_completed': self.blobs_completed,
            'blobs_in_stream': self.blobs_in_stream,
            'blobs_remaining': self.blobs_remaining,
            'status': self.status,
            'claim_id': self.claim_id,
            'txid': self.txid,
            'nout': self.nout,
            'outpoint': self.outpoint,
            'metadata': self.metadata,
            'protobuf': self.metadata_protobuf,
            'channel_claim_id': self.channel_claim_id,
            'channel_name': self.channel_name,
            'claim_name': self.claim_name,
            'content_fee': self.content_fee  # TODO: this isn't in the database
        }

    @classmethod
    async def create(
        cls,
        loop: asyncio.BaseEventLoop,
        config: 'Config',
        blob_manager: 'BlobManager',
        file_path: str,
        key: typing.Optional[bytes] = None,
        iv_generator: typing.Optional[typing.Generator[bytes, None,
                                                       None]] = None
    ) -> 'ManagedStream':
        """Make a stream from the local file ``file_path`` and record it as published.

        The descriptor is stored and a published-file row is saved. Returns a
        ManagedStream with status ``finished`` that points at the source file.
        """
        descriptor = await StreamDescriptor.create_stream(
            loop,
            blob_manager.blob_dir,
            file_path,
            key=key,
            iv_generator=iv_generator,
            blob_completed_callback=blob_manager.blob_completed)
        await blob_manager.storage.store_stream(
            blob_manager.get_blob(descriptor.sd_hash), descriptor)
        row_id = await blob_manager.storage.save_published_file(
            descriptor.stream_hash, os.path.basename(file_path),
            os.path.dirname(file_path), 0)
        return cls(loop,
                   config,
                   blob_manager,
                   descriptor.sd_hash,
                   os.path.dirname(file_path),
                   os.path.basename(file_path),
                   status=cls.STATUS_FINISHED,
                   rowid=row_id,
                   descriptor=descriptor)

    async def setup(self,
                    node: typing.Optional['Node'] = None,
                    save_file: typing.Optional[bool] = True,
                    file_name: typing.Optional[str] = None,
                    download_directory: typing.Optional[str] = None):
        """Start the downloader and begin either stream-only mode or a file write.

        If ``save_file`` is false and no ``file_name`` is given, the stream is
        registered without an output file and an inactivity stop is scheduled.
        Otherwise, when the output file does not exist yet, a file write is
        started and this waits until its first blob has been written.
        """
        await self.downloader.start(node)
        if not save_file and not file_name:
            if not await self.blob_manager.storage.file_exists(self.sd_hash):
                self.rowid = await self.blob_manager.storage.save_downloaded_file(
                    self.stream_hash, None, None, 0.0)
                self.download_directory = None
                self._file_name = None
                self.update_status(ManagedStream.STATUS_RUNNING)
                await self.blob_manager.storage.change_file_status(
                    self.stream_hash, ManagedStream.STATUS_RUNNING)
            self.update_delayed_stop()
        # full_path is None until both a file name and a directory are known,
        # and os.path.isfile(None) raises TypeError. In that case save_file()
        # resolves the directory, falling back to config.download_dir.
        elif not self.full_path or not os.path.isfile(self.full_path):
            await self.save_file(file_name, download_directory)
            # NOTE(review): started_writing is only set after a blob is written,
            # so this wait never completes if the write task fails first.
            await self.started_writing.wait()

    def update_delayed_stop(self):
        """(Re)schedule ``stop_download()`` to run 60 seconds from now."""
        def _delayed_stop():
            log.info("Stopping inactive download for stream %s", self.sd_hash)
            self.stop_download()

        if self.delayed_stop:
            self.delayed_stop.cancel()
        self.delayed_stop = self.loop.call_later(60, _delayed_stop)

    async def aiter_read_stream(
        self,
        start_blob_num: typing.Optional[int] = 0
    ) -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
        """Yield ``(blob_info, decrypted_bytes)`` for each content blob from ``start_blob_num``.

        A pending delayed stop is cancelled before each read. If the reader is
        cancelled while no file write is in progress or finished, the delayed
        stop is re-armed before the cancellation propagates.

        :raises IndexError: ``start_blob_num`` is past the last content blob.
        """
        if start_blob_num >= len(self.descriptor.blobs[:-1]):
            raise IndexError(start_blob_num)
        for i, blob_info in enumerate(
                self.descriptor.blobs[start_blob_num:-1]):
            assert i + start_blob_num == blob_info.blob_num
            if self.delayed_stop:
                self.delayed_stop.cancel()
            try:
                decrypted = await self.downloader.read_blob(blob_info)
                yield (blob_info, decrypted)
            except asyncio.CancelledError:
                if not self.saving.is_set(
                ) and not self.finished_writing.is_set():
                    self.update_delayed_stop()
                raise

    async def _save_file(self, output_path: str):
        """Decrypt every content blob into ``output_path``.

        On success the stream is marked finished (in memory and in storage)
        and ``finished_writing`` is set. On any failure, including
        cancellation, the partially written file is removed and the exception
        is re-raised. ``saving`` is cleared in every case.
        """
        log.debug("save file %s -> %s", self.sd_hash, output_path)
        self.saving.set()
        self.finished_writing.clear()
        self.started_writing.clear()
        try:
            with open(output_path, 'wb') as file_write_handle:
                async for blob_info, decrypted in self.aiter_read_stream():
                    log.info("write blob %i/%i", blob_info.blob_num + 1,
                             len(self.descriptor.blobs) - 1)
                    file_write_handle.write(decrypted)
                    file_write_handle.flush()
                    self.written_bytes += len(decrypted)
                    if not self.started_writing.is_set():
                        self.started_writing.set()
            self.update_status(ManagedStream.STATUS_FINISHED)
            await self.blob_manager.storage.change_file_status(
                self.stream_hash, ManagedStream.STATUS_FINISHED)
            if self.analytics_manager:
                self.loop.create_task(
                    self.analytics_manager.send_download_finished(
                        self.download_id, self.claim_name, self.sd_hash))
            self.finished_writing.set()
        # asyncio.CancelledError derives from BaseException on Python 3.8+, so
        # it is listed explicitly for the incomplete-file cleanup to run.
        except (Exception, asyncio.CancelledError) as err:
            if os.path.isfile(output_path):
                log.info("removing incomplete download %s for %s", output_path,
                         self.sd_hash)
                os.remove(output_path)
            if not isinstance(err, asyncio.CancelledError):
                log.exception(
                    "unexpected error encountered writing file for stream %s",
                    self.sd_hash)
            raise
        finally:
            self.saving.clear()

    async def save_file(self,
                        file_name: typing.Optional[str] = None,
                        download_directory: typing.Optional[str] = None):
        """Start (or restart) writing the stream to disk in a background task.

        The directory is ``download_directory``, else the current one, else
        ``config.download_dir``. It is created if missing. The file name is
        made unique within the directory by ``get_next_available_file_name``.

        :raises ValueError: no directory or no file name can be determined.
        """
        if self.file_output_task and not self.file_output_task.done():
            self.file_output_task.cancel()
        if self.delayed_stop:
            self.delayed_stop.cancel()
            self.delayed_stop = None
        self.download_directory = download_directory or self.download_directory or self.config.download_dir
        if not self.download_directory:
            raise ValueError("no directory to download to")
        if not (file_name or self._file_name
                or self.descriptor.suggested_file_name):
            raise ValueError("no file name to download to")
        if not os.path.isdir(self.download_directory):
            log.warning(
                "download directory '%s' does not exist, attempting to make it",
                self.download_directory)
            os.mkdir(self.download_directory)
        self._file_name = await get_next_available_file_name(
            self.loop, self.download_directory, file_name or self._file_name
            or self.descriptor.suggested_file_name)
        if not await self.blob_manager.storage.file_exists(self.sd_hash):
            # must be awaited (as in setup()); otherwise rowid would hold an
            # un-awaited coroutine and the row would never be saved
            self.rowid = await self.blob_manager.storage.save_downloaded_file(
                self.stream_hash, self.file_name, self.download_directory, 0.0)
        else:
            await self.blob_manager.storage.change_file_download_dir_and_file_name(
                self.stream_hash, self.download_directory, self.file_name)
        self.update_status(ManagedStream.STATUS_RUNNING)
        await self.blob_manager.storage.change_file_status(
            self.stream_hash, ManagedStream.STATUS_RUNNING)
        self.written_bytes = 0
        self.file_output_task = self.loop.create_task(
            self._save_file(self.full_path))

    def stop_download(self):
        """Cancel any running file write and stop the downloader."""
        if self.file_output_task and not self.file_output_task.done():
            self.file_output_task.cancel()
        self.file_output_task = None
        self.downloader.stop()

    async def upload_to_reflector(self, host: str,
                                  port: int) -> typing.List[str]:
        """Send the sd blob and any needed blobs to the reflector at ``host:port``.

        Timeouts, ``ValueError`` and refused connections end the upload early.
        Marks the stream as fully reflected in storage when the upload completes.

        :returns: hashes of the blobs sent (possibly partial).
        """
        sent = []
        protocol = StreamReflectorClient(self.blob_manager, self.descriptor)
        try:
            await self.loop.create_connection(lambda: protocol, host, port)
            await protocol.send_handshake()
            sent_sd, needed = await protocol.send_descriptor()
            if sent_sd:
                sent.append(self.sd_hash)
            if not sent_sd and not needed:
                if not self.fully_reflected.is_set():
                    self.fully_reflected.set()
                    await self.blob_manager.storage.update_reflected_stream(
                        self.sd_hash, f"{host}:{port}")
                    return []
            we_have = [
                blob_hash for blob_hash in needed
                if blob_hash in self.blob_manager.completed_blob_hashes
            ]
            for blob_hash in we_have:
                await protocol.send_blob(blob_hash)
                sent.append(blob_hash)
        except (asyncio.TimeoutError, ValueError):
            return sent
        except ConnectionRefusedError:
            return sent
        finally:
            if protocol.transport:
                protocol.transport.close()
        if not self.fully_reflected.is_set():
            self.fully_reflected.set()
            await self.blob_manager.storage.update_reflected_stream(
                self.sd_hash, f"{host}:{port}")
        return sent

    def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
        """Attach claim info (a resolved/stored claim dict plus its decoded Claim)."""
        self.stream_claim_info = StoredStreamClaim(
            self.stream_hash, f"{claim_info['txid']}:{claim_info['nout']}",
            claim_info['claim_id'], claim_info['name'], claim_info['amount'],
            claim_info['height'],
            binascii.hexlify(claim.to_bytes()).decode(),
            claim.signing_channel_id, claim_info['address'],
            claim_info['claim_sequence'], claim_info.get('channel_name'))
예제 #12
0
class TestStreamDownloader(BlobExchangeTestBase):
    """End-to-end StreamDownloader transfers using a mocked peer-search junction."""

    async def setup_stream(self, blob_count: int = 10):
        """Write ``blob_count`` random blobs server-side and make a client downloader."""
        self.stream_bytes = b''.join(
            os.urandom(MAX_BLOB_SIZE - 1) for _ in range(blob_count))
        # create the stream from a file in the server directory
        source_path = os.path.join(self.server_dir, "test_file")
        with open(source_path, 'wb') as source:
            source.write(self.stream_bytes)
        descriptor = await StreamDescriptor.create_stream(
            self.loop, self.server_blob_manager.blob_dir, source_path)
        self.sd_hash = descriptor.calculate_sd_hash()
        self.downloader = StreamDownloader(
            self.loop, self.client_blob_manager, self.sd_hash, 3, 3,
            self.client_dir)

    def _assert_output_matches_source(self):
        """Check the downloaded file exists and matches the original stream bytes."""
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as downloaded:
            self.assertEqual(downloaded.read(), self.stream_bytes)

    async def _test_transfer_stream(self,
                                    blob_count: int,
                                    mock_peer_search=None):
        """Transfer a ``blob_count``-blob stream; by default only the local server is found."""
        await self.setup_stream(blob_count)

        node = mock.Mock(spec=Node)

        @contextlib.asynccontextmanager
        async def _single_server_search(*_):
            async def _peer_batches():
                yield [self.server_from_client]

            yield _peer_batches()

        node.stream_peer_search_junction = mock_peer_search or _single_server_search

        self.downloader.download(node)
        await self.downloader.stream_finished_event.wait()
        await self.downloader.stop()
        self._assert_output_matches_source()

    async def test_transfer_stream(self):
        await self._test_transfer_stream(10)

    # async def test_transfer_hundred_blob_stream(self):
    #     await self._test_transfer_stream(100)

    async def test_transfer_stream_bad_first_peer_good_second(self):
        """A dead first peer is followed by the real server; the transfer must succeed."""
        await self.setup_stream(2)

        node = mock.Mock(spec=Node)

        unreachable_peer = KademliaPeer(self.loop,
                                        "127.0.0.1",
                                        b'2' * 48,
                                        tcp_port=3334)

        @contextlib.asynccontextmanager
        async def _bad_then_good_search(*_):
            async def _peer_batches():
                await asyncio.sleep(0.05, loop=self.loop)
                yield [unreachable_peer]
                await asyncio.sleep(0.1, loop=self.loop)
                yield [self.server_from_client]

            yield _peer_batches()

        node.stream_peer_search_junction = _bad_then_good_search

        self.downloader.download(node)
        await self.downloader.stream_finished_event.wait()
        await self.downloader.stop()
        self._assert_output_matches_source()
예제 #13
0
class TestStreamDownloader(BlobExchangeTestBase):
    """Tests that download a stream from the blob server set up by BlobExchangeTestBase."""

    async def setup_stream(self, blob_count: int = 10):
        """Create a random stream on the server side and a client downloader for it.

        Writes `blob_count` random chunks of ``MAX_BLOB_SIZE - 1`` bytes to a file in
        the server directory, builds a stream descriptor for it and sets
        ``self.stream_bytes``, ``self.sd_hash`` and ``self.downloader``.
        """
        # Join once instead of repeated ``bytes +=`` (which copies on every step).
        self.stream_bytes = b''.join(os.urandom(MAX_BLOB_SIZE - 1) for _ in range(blob_count))
        # create the stream
        file_path = os.path.join(self.server_dir, "test_file")
        with open(file_path, 'wb') as f:
            f.write(self.stream_bytes)
        descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
        self.sd_hash = descriptor.calculate_sd_hash()
        conf = Config(data_dir=self.server_dir, wallet_dir=self.server_dir, download_dir=self.server_dir,
                      reflector_servers=[])
        self.downloader = StreamDownloader(self.loop, conf, self.client_blob_manager, self.sd_hash)

    async def _test_transfer_stream(self, blob_count: int, mock_accumulate_peers=None):
        """Download a `blob_count`-blob stream and check the written file.

        :param blob_count: number of blobs in the generated stream.
        :param mock_accumulate_peers: optional replacement for the node's
            ``accumulate_peers``; by default the server peer is offered immediately.
        """
        await self.setup_stream(blob_count)
        mock_node = mock.Mock(spec=Node)

        def _mock_accumulate_peers(q1, q2):
            # Hand back the output queue pre-loaded with the server peer, plus a
            # no-op task standing in for the peer search task.
            async def _task():
                pass
            q2.put_nowait([self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers
        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        self.assertTrue(self.downloader.stream_handle.closed)
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        self.downloader.stop()
        self.assertIs(self.downloader.stream_handle, None)
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
        await asyncio.sleep(0.01)

    async def test_transfer_stream(self):
        await self._test_transfer_stream(10)

    # unittest.SkipTest is an exception class, not a decorator: applying it replaced
    # the method with a non-callable exception instance, so the loader silently
    # dropped the test. unittest.skip reports it as skipped instead.
    @unittest.skip("hundred blob stream transfer is disabled")
    async def test_transfer_hundred_blob_stream(self):
        await self._test_transfer_stream(100)

    async def test_transfer_stream_bad_first_peer_good_second(self):
        """Finish a two-blob download when the first peer offered is not usable.

        The mocked ``accumulate_peers`` offers `bad_peer` immediately and the real
        server only one second later.
        """
        await self.setup_stream(2)

        mock_node = mock.Mock(spec=Node)

        # NOTE(review): presumably nothing listens on 127.0.0.1:3334 during this
        # test, so this peer cannot serve blobs — confirm against the test base.
        bad_peer = KademliaPeer(self.loop, "127.0.0.1", b'2' * 48, tcp_port=3334)

        def _mock_accumulate_peers(q1, q2):
            async def _task():
                pass

            q2.put_nowait([bad_peer])
            self.loop.call_later(1, q2.put_nowait, [self.server_from_client])
            return q2, self.loop.create_task(_task())

        mock_node.accumulate_peers = _mock_accumulate_peers

        self.downloader.download(mock_node)
        await self.downloader.stream_finished_event.wait()
        # Release the downloader, as _test_transfer_stream does.
        self.downloader.stop()
        self.assertTrue(os.path.isfile(self.downloader.output_path))
        with open(self.downloader.output_path, 'rb') as f:
            self.assertEqual(f.read(), self.stream_bytes)
        # TODO: re-enable once peer bookkeeping is reliable here:
        #   self.assertIs(self.server_from_client.tcp_last_down, None)
        #   self.assertIsNot(bad_peer.tcp_last_down, None)

    async def test_client_chunked_response(self):
        """Download a stream from a server that sends its responses one byte at a time."""
        self.server.stop_server()

        class ChunkedServerProtocol(BlobServerProtocol):

            def send_response(self, responses):
                # Drain `responses` from the end, then write the serialized
                # response one byte per transport.write call so the client has
                # to reassemble it from many small reads.
                to_send = []
                while responses:
                    to_send.append(responses.pop())
                for byte in BlobResponse(to_send).serialize():
                    self.transport.write(bytes([byte]))

        self.server.server_protocol_class = ChunkedServerProtocol
        self.server.start_server(33333, '127.0.0.1')
        self.assertEqual(0, len(self.client_blob_manager.completed_blob_hashes))
        await asyncio.wait_for(self._test_transfer_stream(10), timeout=2)
        # Presumably the 10 stream blobs plus the sd blob — confirm.
        self.assertEqual(11, len(self.client_blob_manager.completed_blob_hashes))