コード例 #1
0
async def setup_harvester(port, dic=None):
    """Create a test harvester and its ChiaServer, yield them, then tear down.

    This is an async-generator fixture: the code after ``yield`` is the
    teardown and runs in a ``finally`` so it also executes if the generator
    is closed early or an exception is thrown into it.

    Args:
        port: Port given to the ``ChiaServer``; also used in its name.
        dic: Accepted for signature parity with the other ``setup_*``
            helpers; it is not read here. Defaults to ``None`` instead of a
            shared mutable ``{}``.

    Yields:
        ``(harvester, server)``.
    """
    config = load_config(root_path, "config.yaml", "harvester")

    harvester = await Harvester.create(config, bt.plot_config)

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")
    # ChiaServer needs both values; fail fast on an incomplete config.
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        harvester,
        NodeType.HARVESTER,
        ping_interval,
        network_id,
        root_path,
        config,
        f"harvester_server_{port}",
    )

    try:
        yield (harvester, server)
    finally:
        # Request shutdown on both objects first, then await each of them.
        harvester._shutdown()
        server.close_all()
        await harvester._await_shutdown()
        await server.await_closed()
コード例 #2
0
async def setup_farmer(port, dic=None):
    """Create a test farmer (old key_config API) and its server; tear down after.

    Args:
        port: Port the farmer's ``ChiaServer`` is created on.
        dic: Optional overrides copied key-by-key into a copy of
            ``test_constants``. Defaults to ``None`` instead of a shared
            mutable ``{}``.

    Yields:
        ``(farmer, server)``.
    """
    if dic is None:
        dic = {}
    config = load_config(root_path, "config.yaml", "farmer")
    pool_sk = bt.pool_sk
    pool_target = create_puzzlehash_for_pk(
        BLSPublicKey(bytes(pool_sk.get_public_key())))
    wallet_sk = bt.wallet_sk
    wallet_target = create_puzzlehash_for_pk(
        BLSPublicKey(bytes(wallet_sk.get_public_key())))

    # Keys and targets are handed to the Farmer as hex strings.
    key_config = {
        "wallet_sk": bytes(wallet_sk).hex(),
        "wallet_target": wallet_target.hex(),
        "pool_sks": [bytes(pool_sk).hex()],
        "pool_target": pool_target.hex(),
    }
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    farmer = Farmer(config, key_config, test_constants_copy)
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(port, farmer, NodeType.FARMER, ping_interval,
                        network_id)
    await server.start_server(farmer._on_connect)

    try:
        yield (farmer, server)
    finally:
        server.close_all()
        await server.await_closed()
コード例 #3
0
async def setup_introducer(port, dic=None):
    """Create a test introducer and start its server; tear down after.

    Args:
        port: Port the introducer's ``ChiaServer`` is created on.
        dic: Accepted for signature parity with the other ``setup_*``
            helpers; not read here. Defaults to ``None`` instead of a shared
            mutable ``{}``.

    Yields:
        ``(introducer, server)``.
    """
    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    config = load_config(root_path, "config.yaml", "introducer")

    introducer = Introducer(config["max_peers_to_send"],
                            config["recent_peer_threshold"])
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        introducer,
        NodeType.INTRODUCER,
        ping_interval,
        network_id,
        bt.root_path,
        config,
        f"introducer_server_{port}",
    )
    # Handle returned by start_server; it is closed during teardown.
    listener = await start_server(server)

    try:
        yield (introducer, server)
    finally:
        listener.close()
        server.close_all()
        await server.await_closed()
コード例 #4
0
ファイル: setup_nodes.py プロジェクト: Shar2216/Virtual
async def setup_two_nodes():
    """
    Setup and teardown of two full nodes, with blockchains and separate DBs.

    Only ``server_1`` (port 21234) is started as a listening server;
    ``server_2`` (port 21235) is created and attached to ``full_node_2`` but
    not started here. Teardown runs in a ``finally`` so servers and stores
    are released even if the generator is closed early.

    Yields:
        ``(full_node_1, full_node_2, server_1, server_2)``.
    """

    # SETUP
    store_1 = await FullNodeStore.create("blockchain_test")
    store_2 = await FullNodeStore.create("blockchain_test_2")
    # Start each test from an empty database.
    await store_1._clear_database()
    await store_2._clear_database()
    b_1: Blockchain = await Blockchain.create({}, test_constants)
    b_2: Blockchain = await Blockchain.create({}, test_constants)
    await store_1.add_block(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))
    await store_2.add_block(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))

    full_node_1 = FullNode(store_1, b_1)
    server_1 = ChiaServer(21234, full_node_1, NodeType.FULL_NODE)
    await server_1.start_server("127.0.0.1", full_node_1._on_connect)
    full_node_1._set_server(server_1)

    full_node_2 = FullNode(store_2, b_2)
    server_2 = ChiaServer(21235, full_node_2, NodeType.FULL_NODE)
    full_node_2._set_server(server_2)

    try:
        yield (full_node_1, full_node_2, server_1, server_2)
    finally:
        # TEARDOWN
        full_node_1._shutdown()
        full_node_2._shutdown()
        server_1.close_all()
        server_2.close_all()
        await server_1.await_closed()
        await server_2.await_closed()
        await store_1.close()
        await store_2.close()
コード例 #5
0
async def setup_introducer(port, dic=None):
    """Create a test introducer from its config section and start its server.

    Args:
        port: Port the introducer's ``ChiaServer`` is created on.
        dic: Accepted for signature parity with the other ``setup_*``
            helpers; not read here. Defaults to ``None`` instead of a shared
            mutable ``{}``.

    Yields:
        ``(introducer, server)``.
    """
    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    config = load_config(root_path, "config.yaml", "introducer")

    introducer = Introducer(config)
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        introducer,
        NodeType.INTRODUCER,
        ping_interval,
        network_id,
        bt.root_path,
        config,
    )
    # No on-connect callback is needed for the introducer.
    await server.start_server(None)

    try:
        yield (introducer, server)
    finally:
        server.close_all()
        await server.await_closed()
コード例 #6
0
async def setup_full_node(db_name, port, introducer_port=None, dic=None):
    """Start a full node backed by a fresh SQLite file; tear everything down after.

    Args:
        db_name: Path of the SQLite database file; it is deleted on teardown.
        port: Port the node's ``ChiaServer`` is created on.
        introducer_port: If given, the node's ``introducer_peer`` config is
            pointed at ``127.0.0.1:<introducer_port>``.
        dic: Optional overrides copied into a copy of ``test_constants``.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(full_node_1, server_1)``.
    """
    if dic is None:
        dic = {}
    # SETUP
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value

    db_path = Path(db_name)
    # The block store and the coin store share this single connection.
    connection = await aiosqlite.connect(db_path)
    store_1 = await FullNodeStore.create(connection)
    await store_1._clear_database()
    unspent_store_1 = await CoinStore.create(connection)
    await unspent_store_1._clear_database()
    mempool_1 = MempoolManager(unspent_store_1, test_constants_copy)

    b_1: Blockchain = await Blockchain.create(unspent_store_1, store_1,
                                              test_constants_copy)
    await mempool_1.new_tips(await b_1.get_full_tips())

    await store_1.add_block(
        FullBlock.from_bytes(test_constants_copy["GENESIS_BLOCK"]))

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    config = load_config(root_path, "config.yaml", "full_node")
    if introducer_port is not None:
        config["introducer_peer"]["host"] = "127.0.0.1"
        config["introducer_peer"]["port"] = introducer_port
    full_node_1 = FullNode(
        store_1,
        b_1,
        config,
        mempool_1,
        unspent_store_1,
        f"full_node_{port}",
        test_constants_copy,
    )
    assert ping_interval is not None
    assert network_id is not None
    server_1 = ChiaServer(
        port,
        full_node_1,
        NodeType.FULL_NODE,
        ping_interval,
        network_id,
        root_path,
        config,
    )
    await server_1.start_server(full_node_1._on_connect)
    full_node_1._set_server(server_1)

    try:
        yield (full_node_1, server_1)
    finally:
        # TEARDOWN
        full_node_1._shutdown()
        server_1.close_all()
        # Wait for the server to finish closing before closing the DB
        # connection that its handlers may still be using.
        await server_1.await_closed()
        await connection.close()
        db_path.unlink()
コード例 #7
0
async def setup_full_node_simulator(db_name,
                                    port,
                                    introducer_port=None,
                                    dic=None):
    """Start a FullNodeSimulator on a fresh DB under ``root_path``; tear down after.

    Args:
        db_name: File name (relative to ``root_path``) for the node database.
            Any stale file is removed before start and after teardown.
        port: Port the simulator's ``ChiaServer`` is created on.
        introducer_port: If given, the node's ``introducer_peer`` config is
            pointed at ``127.0.0.1:<introducer_port>``.
        dic: Optional overrides copied into a copy of ``test_constants``.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(full_node_1, server_1)``.
    """
    if dic is None:
        dic = {}
    # SETUP
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value

    db_path = root_path / f"{db_name}"
    if db_path.exists():
        db_path.unlink()

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    config = load_config(root_path, "config.yaml", "full_node")
    config["database_path"] = str(db_path)

    if introducer_port is not None:
        config["introducer_peer"]["host"] = "127.0.0.1"
        config["introducer_peer"]["port"] = introducer_port
    full_node_1 = await FullNodeSimulator.create(
        config=config,
        name=f"full_node_{port}",
        root_path=root_path,
        override_constants=test_constants_copy,
    )
    assert ping_interval is not None
    assert network_id is not None
    server_1 = ChiaServer(
        port,
        full_node_1,
        NodeType.FULL_NODE,
        ping_interval,
        network_id,
        bt.root_path,
        config,
        "full-node-simulator-server",
    )
    # Handle returned by start_server; it is closed during teardown.
    listener = await start_server(server_1, full_node_1._on_connect)
    full_node_1._set_server(server_1)

    try:
        yield (full_node_1, server_1)
    finally:
        # TEARDOWN
        listener.close()
        server_1.close_all()
        full_node_1._close()
        await server_1.await_closed()
        await full_node_1._await_closed()
        # Guarded: an unconditional unlink raises FileNotFoundError if the
        # database file was never created.
        if db_path.exists():
            db_path.unlink()
コード例 #8
0
async def setup_wallet_node(port,
                            introducer_port=None,
                            key_seed=b"setup_wallet_node",
                            dic=None):
    """Create a test wallet node with a keychain-derived key; tear down after.

    Args:
        port: Port the wallet's ``ChiaServer`` is created on; also part of the
            DB file name ``test-wallet-db-<port>.db``.
        introducer_port: Accepted for signature compatibility; not read here.
        key_seed: Seed added to a ``Keychain`` whose first private key is
            used for the wallet.
        dic: Optional overrides copied into a copy of ``test_constants``; a
            ``"starting_height"`` entry is also copied into the wallet config.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(wallet, server)``.
    """
    if dic is None:
        dic = {}
    config = load_config(root_path, "config.yaml", "wallet")
    if "starting_height" in dic:
        config["starting_height"] = dic["starting_height"]

    # NOTE(review): the second Keychain argument is presumably a "testing"
    # flag -- confirm against Keychain's signature.
    keychain = Keychain(key_seed.hex(), True)
    keychain.add_private_key_seed(key_seed)
    private_key = keychain.get_all_private_keys()[0][0]
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value
    db_path = root_path / f"test-wallet-db-{port}.db"
    if db_path.exists():
        db_path.unlink()
    config["database_path"] = str(db_path)

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    wallet = await WalletNode.create(
        config,
        private_key,
        root_path,
        override_constants=test_constants_copy,
        name="wallet1",
    )
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        wallet,
        NodeType.WALLET,
        ping_interval,
        network_id,
        root_path,
        config,
        "wallet-server",
    )
    wallet.set_server(server)

    try:
        yield (wallet, server)
    finally:
        server.close_all()
        await wallet.wallet_state_manager.clear_all_stores()
        await wallet.wallet_state_manager.close_all_stores()
        wallet.wallet_state_manager.unlink_db()
        await server.await_closed()
コード例 #9
0
async def setup_timelord(port, dic=None):
    """Start a test timelord, its ChiaServer and its VDF listener; tear down after.

    Args:
        port: Port the timelord's ``ChiaServer`` is created on; also used in
            its name.
        dic: Optional overrides copied into a copy of ``test_constants``.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(timelord, server)``.
    """
    if dic is None:
        dic = {}
    config = load_config(root_path, "config.yaml", "timelord")

    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value
    timelord = Timelord(config, test_constants_copy)

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        timelord,
        NodeType.TIMELORD,
        ping_interval,
        network_id,
        bt.root_path,
        config,
        f"timelord_server_{port}",
    )

    # Listener for the VDF server address from the timelord config, with
    # timelord._handle_client as the connection handler. It is awaited
    # directly rather than wrapped in a task, so bind errors surface here
    # and teardown holds the real asyncio.Server. Cancelling an
    # already-finished task never closed the socket. ``loop=`` is omitted
    # because asyncio.start_server no longer accepts it (removed in 3.10).
    vdf_server = await asyncio.start_server(
        timelord._handle_client,
        config["vdf_server"]["host"],
        config["vdf_server"]["port"],
    )

    timelord.set_server(server)
    timelord._start_bg_tasks()

    async def run_timelord():
        # Push every message yielded by the discriminant-queue manager.
        async for msg in timelord._manage_discriminant_queue():
            server.push_message(msg)

    timelord_task = asyncio.create_task(run_timelord())

    try:
        yield (timelord, server)
    finally:
        # Stop listening for new VDF client connections. wait_closed() is not
        # awaited: on newer Pythons it also waits for already-open client
        # connections, which could stall teardown.
        vdf_server.close()
        server.close_all()
        await timelord._shutdown()
        await timelord_task
        await server.await_closed()
コード例 #10
0
    async def test1(self):
        """Time how long a node takes to process 1000 copies of an unfinished block.

        Builds a 10-block chain and feeds blocks 1-8 into the blockchain.
        It then connects ``server_2`` to ``server_1`` and broadcasts
        ``blocks[9]`` as an unfinished block 1000 times, followed by the full
        ``blocks[9]``. The test succeeds once a tip at height 9 appears, and
        fails if that takes longer than 300 seconds. Both servers are closed
        in a ``finally`` so they are released on every exit path.
        """
        store = FullNodeStore("fndb_test")
        await store._clear_database()
        blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10)
        b: Blockchain = Blockchain(test_constants)
        await store.add_block(blocks[0])
        await b.initialize({})
        for i in range(1, 9):
            assert (await
                    b.receive_block(blocks[i]
                                    )) == ReceiveBlockResult.ADDED_TO_HEAD
            await store.add_block(blocks[i])

        full_node_1 = FullNode(store, b)
        server_1 = ChiaServer(21234, full_node_1, NodeType.FULL_NODE)
        await server_1.start_server("127.0.0.1", None)
        full_node_1._set_server(server_1)

        full_node_2 = FullNode(store, b)
        server_2 = ChiaServer(21235, full_node_2, NodeType.FULL_NODE)
        full_node_2._set_server(server_2)

        try:
            await server_2.start_client(PeerInfo("127.0.0.1", uint16(21234)),
                                        None)

            await asyncio.sleep(2)  # Allow connections to get made

            num_unfinished_blocks = 1000
            # Monotonic clock: the 300 s timeout must not be affected by
            # wall-clock adjustments.
            start_unf = time.monotonic()
            for _ in range(num_unfinished_blocks):
                msg = Message("unfinished_block",
                              peer_protocol.UnfinishedBlock(blocks[9]))
                server_1.push_message(
                    OutboundMessage(NodeType.FULL_NODE, msg,
                                    Delivery.BROADCAST))

            # Send the whole block at the end so we can detect when the node is done
            block_msg = Message("block", peer_protocol.Block(blocks[9]))
            server_1.push_message(
                OutboundMessage(NodeType.FULL_NODE, block_msg,
                                Delivery.BROADCAST))

            while time.monotonic() - start_unf < 300:
                if max(h.height for h in b.get_current_tips()) == 9:
                    print(
                        f"Time taken to process {num_unfinished_blocks} is {time.monotonic() - start_unf}"
                    )
                    return
                await asyncio.sleep(0.1)

            raise Exception("Took too long to process blocks")
        finally:
            server_1.close_all()
            server_2.close_all()
            await server_1.await_closed()
            await server_2.await_closed()
コード例 #11
0
async def setup_farmer(port, dic=None):
    """Create a test farmer from the farmer/pool config sections; tear down after.

    Both configs get ``xch_target_puzzle_hash`` set to ``bt.fee_target``. The
    farmer config's ``pool_public_keys`` is set from the keychain's public
    keys.

    Args:
        port: Port the farmer's ``ChiaServer`` is created on; also used in
            its name.
        dic: Optional overrides copied into a copy of ``test_constants``.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(farmer, server)``.
    """
    if dic is None:
        dic = {}
    # (A leftover debug ``print("root path", ...)`` was removed here.)
    config = load_config(root_path, "config.yaml", "farmer")
    config_pool = load_config(root_path, "config.yaml", "pool")
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    config["xch_target_puzzle_hash"] = bt.fee_target.hex()
    config["pool_public_keys"] = [
        bytes(epk.get_public_key()).hex()
        for epk in bt.keychain.get_all_public_keys()
    ]
    config_pool["xch_target_puzzle_hash"] = bt.fee_target.hex()

    farmer = Farmer(config, config_pool, bt.keychain, test_constants_copy)
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(
        port,
        farmer,
        NodeType.FARMER,
        ping_interval,
        network_id,
        root_path,
        config,
        f"farmer_server_{port}",
    )
    farmer.set_server(server)
    # Handle returned by start_server; it is closed during teardown.
    listener = await start_server(server, farmer._on_connect)

    try:
        yield (farmer, server)
    finally:
        listener.close()
        server.close_all()
        await server.await_closed()
コード例 #12
0
async def setup_wallet_node(port, introducer_port=None, key_seed=b"", dic=None):
    """Create a test wallet node (old key_config API) on a random-named DB.

    Args:
        port: Port the wallet's ``ChiaServer`` is created on.
        introducer_port: Accepted for signature compatibility; not read here.
        key_seed: Seed for ``blspy.ExtendedPrivateKey.from_seed``; the
            resulting key is passed as ``wallet_sk`` (hex).
        dic: Optional overrides copied into a copy of ``test_constants``; a
            ``"starting_height"`` entry is also copied into the wallet config.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Yields:
        ``(wallet, server)``.
    """
    if dic is None:
        dic = {}
    config = load_config(root_path, "config.yaml", "wallet")
    if "starting_height" in dic:
        config["starting_height"] = dic["starting_height"]
    key_config = {
        "wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(key_seed)).hex(),
    }
    test_constants_copy = test_constants.copy()
    for key, value in dic.items():
        test_constants_copy[key] = value
    # A random 32-byte suffix keeps concurrent/parallel test DBs apart.
    db_path = root_path / f"test-wallet-db{token_bytes(32).hex()}.db"
    if db_path.exists():
        db_path.unlink()
    config["database_path"] = str(db_path)

    net_config = load_config(root_path, "config.yaml")
    ping_interval = net_config.get("ping_interval")
    network_id = net_config.get("network_id")

    wallet = await WalletNode.create(
        config,
        key_config,
        override_constants=test_constants_copy,
        name="wallet1",
    )
    assert ping_interval is not None
    assert network_id is not None
    server = ChiaServer(port, wallet, NodeType.WALLET, ping_interval,
                        network_id, "wallet-server")
    wallet.set_server(server)

    try:
        yield (wallet, server)
    finally:
        server.close_all()
        await wallet.wallet_state_manager.clear_all_stores()
        await wallet.wallet_state_manager.close_all_stores()
        wallet.wallet_state_manager.unlink_db()
        await server.await_closed()
コード例 #13
0
    async def test2(self):
        """Time how long a node takes to process 99 broadcast full blocks.

        ``server_2`` connects to ``server_1``. Blocks ``1..num_blocks-1`` are
        then broadcast from ``server_1``. The test succeeds once a tip reaches
        height ``num_blocks - 1``, and fails after 300 seconds. Both servers
        are closed in a ``finally`` so they are released on every exit path.
        """
        num_blocks = 100
        store = FullNodeStore("fndb_test")
        await store._clear_database()
        blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
        b: Blockchain = Blockchain(test_constants)
        await store.add_block(blocks[0])
        await b.initialize({})

        full_node_1 = FullNode(store, b)
        server_1 = ChiaServer(21236, full_node_1, NodeType.FULL_NODE)
        await server_1.start_server("127.0.0.1", None)
        full_node_1._set_server(server_1)

        full_node_2 = FullNode(store, b)
        server_2 = ChiaServer(21237, full_node_2, NodeType.FULL_NODE)
        full_node_2._set_server(server_2)

        try:
            await server_2.start_client(PeerInfo("127.0.0.1", uint16(21236)),
                                        None)

            await asyncio.sleep(2)  # Allow connections to get made

            # Monotonic clock: the 300 s timeout must not be affected by
            # wall-clock adjustments.
            start_unf = time.monotonic()
            for i in range(1, num_blocks):
                msg = Message("block", peer_protocol.Block(blocks[i]))
                server_1.push_message(
                    OutboundMessage(NodeType.FULL_NODE, msg,
                                    Delivery.BROADCAST))

            while time.monotonic() - start_unf < 300:
                if max(h.height
                       for h in b.get_current_tips()) == num_blocks - 1:
                    print(
                        f"Time taken to process {num_blocks} is {time.monotonic() - start_unf}"
                    )
                    return
                await asyncio.sleep(0.1)

            raise Exception("Took too long to process blocks")
        finally:
            server_1.close_all()
            server_2.close_all()
            await server_1.await_closed()
            await server_2.await_closed()
コード例 #14
0
    async def test1(self):
        """Exercise the full-node RPC API against a small local chain.

        Builds a 10-block chain, feeds blocks 1-8 into the blockchain and
        starts ``full_node_1`` with an RPC server. It then checks, through an
        ``RpcClient``: blockchain state, block and header lookup, pool
        balances, and opening and closing a connection to a second node.

        Cleanup runs in a ``finally`` so it happens for any exception, not
        only ``AssertionError``. ``client`` and ``server_2`` start as
        ``None`` so cleanup still works when the failure occurs before they
        exist; in that case the original exception propagates instead of a
        ``NameError``.
        """
        test_node_1_port = 21234
        test_node_2_port = 21235
        test_rpc_port = 21236
        db_filename = "blockchain_test"

        if os.path.isfile(db_filename):
            os.remove(db_filename)
        store = await FullNodeStore.create(db_filename)
        await store._clear_database()
        blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10)
        b: Blockchain = await Blockchain.create({}, test_constants)
        await store.add_block(blocks[0])
        for i in range(1, 9):
            assert (await b.receive_block(
                blocks[i],
                blocks[i -
                       1].header_block)) == ReceiveBlockResult.ADDED_TO_HEAD
            await store.add_block(blocks[i])

        config = load_config("config.yaml", "full_node")
        full_node_1 = FullNode(store, b, config)
        server_1 = ChiaServer(test_node_1_port, full_node_1,
                              NodeType.FULL_NODE)
        await server_1.start_server("127.0.0.1", None)
        full_node_1._set_server(server_1)

        def stop_node_cb():
            # Passed to the RPC server; shuts down node 1 and its server.
            full_node_1._shutdown()
            server_1.close_all()

        rpc_cleanup = await start_rpc_server(full_node_1, stop_node_cb,
                                             test_rpc_port)

        # Bound up-front so the cleanup can tell which resources exist.
        client = None
        server_2 = None
        try:
            client = await RpcClient.create(test_rpc_port)
            state = await client.get_blockchain_state()
            assert state["lca"].header_hash is not None
            assert not state["sync_mode"]
            assert len(state["tips"]) > 0
            assert state["difficulty"] > 0
            assert state["ips"] > 0

            block = await client.get_block(state["lca"].header_hash)
            assert block == blocks[6]
            assert (await client.get_block(bytes([1] * 32))) is None

            small_header_block = await client.get_header(
                state["lca"].header_hash)
            assert small_header_block.header == blocks[6].header_block.header

            assert len(await client.get_pool_balances()) > 0
            assert len(await client.get_connections()) == 0

            full_node_2 = FullNode(store, b, config)
            server_2 = ChiaServer(test_node_2_port, full_node_2,
                                  NodeType.FULL_NODE)
            full_node_2._set_server(server_2)

            await server_2.start_server("127.0.0.1", None)
            await asyncio.sleep(2)  # Allow server to start
            cons = await client.get_connections()
            assert len(cons) == 0

            # Open a connection through the RPC
            await client.open_connection(host="127.0.0.1",
                                         port=test_node_2_port)
            cons = await client.get_connections()
            assert len(cons) == 1

            # Close a connection through the RPC
            await client.close_connection(cons[0]["node_id"])
            cons = await client.get_connections()
            assert len(cons) == 0
        finally:
            if client is not None:
                # Checks that the RPC manages to stop the node
                await client.stop_node()
                client.close()
                await client.await_closed()
            else:
                # No RPC client was created, so stop node 1 directly; this
                # closes server_1 before it is awaited below.
                stop_node_cb()
            if server_2 is not None:
                server_2.close_all()
            await server_1.await_closed()
            if server_2 is not None:
                await server_2.await_closed()
            await rpc_cleanup()
            await store.close()
コード例 #15
0
ファイル: start_service.py プロジェクト: spring3th/Exodus
class Service:
    """Run an API object as a service behind a ChiaServer.

    ``__init__`` does the following:

    - reads ``ping_interval``/``network_id`` and the service's config section
      from ``config.yaml``;
    - sets the process title and initialises logging;
    - builds the ``ChiaServer`` and hands it to the API.

    ``start()`` schedules the run loop: start callback, optional introducer
    polling, optional RPC, reconnect tasks, listeners and signal handlers.
    ``run()`` starts it and waits for shutdown. ``stop()`` closes listeners,
    cancels background tasks and closes all connections.
    """

    def __init__(
        self,
        root_path,
        api: Any,
        node_type: NodeType,
        advertised_port: int,
        service_name: str,
        server_listen_ports: Optional[List[int]] = None,
        connect_peers: Optional[List[PeerInfo]] = None,
        on_connect_callback: Optional[Callable] = None,
        rpc_start_callback_port: Optional[Tuple[Callable, int]] = None,
        start_callback: Optional[Callable] = None,
        stop_callback: Optional[Callable] = None,
        await_closed_callback: Optional[Callable] = None,
        periodic_introducer_poll: Optional[Tuple[PeerInfo, int, int]] = None,
    ):
        """Configure the service; nothing runs until ``start()``.

        Args:
            root_path: Directory containing ``config.yaml``.
            api: Object served by the ChiaServer. Its ``set_server`` and/or
                ``_set_server`` method is called with the server if present.
            node_type: NodeType given to the ChiaServer.
            advertised_port: Not read by this class (the ChiaServer uses
                ``config["port"]``); kept for interface compatibility.
            service_name: Config section, logger name and process-title
                suffix (``chia_<service_name>``).
            server_listen_ports: One listener is started per entry. ``None``
                (the default) means no listeners.
            connect_peers: Peers passed to ``start_reconnect_task``. ``None``
                (the default) means no peers.
            on_connect_callback: Callback passed to ``start_server`` for each
                listener.
            rpc_start_callback_port: Optional ``(rpc_start_fn, port)``; the
                function is called as ``rpc_start_fn(api, self.stop, port)``.
            start_callback: Awaited at the start of the run loop.
            stop_callback: Called (not awaited) from ``stop()``.
            await_closed_callback: Awaited after the server has closed.
            periodic_introducer_poll: Optional
                ``(peer_info, connect_interval, target_peer_count)``.
        """
        net_config = load_config(root_path, "config.yaml")
        ping_interval = net_config.get("ping_interval")
        network_id = net_config.get("network_id")
        assert ping_interval is not None
        assert network_id is not None

        self._node_type = node_type

        proctitle_name = f"chia_{service_name}"
        setproctitle(proctitle_name)
        self._log = logging.getLogger(service_name)

        config = load_config_cli(root_path, "config.yaml", service_name)
        initialize_logging(f"{service_name:<30s}", config["logging"],
                           root_path)

        self._rpc_start_callback_port = rpc_start_callback_port

        self._server = ChiaServer(
            config["port"],
            api,
            node_type,
            ping_interval,
            network_id,
            root_path,
            config,
        )
        for setter_name in ["set_server", "_set_server"]:
            setter = getattr(api, setter_name, None)
            if setter:
                setter(self._server)

        self._connect_peers = connect_peers if connect_peers is not None else []
        self._server_listen_ports = (server_listen_ports
                                     if server_listen_ports is not None else [])

        self._api = api
        self._task = None
        self._is_stopping = False

        # Filled in by the run loop. Initialised here so that stop() and
        # wait_closed() do not raise AttributeError if called before the run
        # loop has created them.
        self._server_sockets: List = []
        self._reconnect_tasks: List = []
        self._introducer_poll_task = None
        self._rpc_task = None

        self._periodic_introducer_poll = periodic_introducer_poll
        self._on_connect_callback = on_connect_callback
        self._start_callback = start_callback
        self._stop_callback = stop_callback
        self._await_closed_callback = await_closed_callback

    def start(self):
        """Schedule the service run loop as a task (no-op if already started)."""
        if self._task is not None:
            return

        async def _run():
            if self._start_callback:
                await self._start_callback()

            self._introducer_poll_task = None
            if self._periodic_introducer_poll:
                (
                    peer_info,
                    introducer_connect_interval,
                    target_peer_count,
                ) = self._periodic_introducer_poll

                self._introducer_poll_task = create_periodic_introducer_poll_task(
                    self._server,
                    peer_info,
                    self._server.global_connections,
                    introducer_connect_interval,
                    target_peer_count,
                )

            self._rpc_task = None
            if self._rpc_start_callback_port:
                rpc_f, rpc_port = self._rpc_start_callback_port
                self._rpc_task = asyncio.ensure_future(
                    rpc_f(self._api, self.stop, rpc_port))

            self._reconnect_tasks = [
                start_reconnect_task(self._server, peer, self._log)
                for peer in self._connect_peers
            ]
            # One listener per entry in server_listen_ports. NOTE(review):
            # the port value itself is not passed to start_server --
            # presumably it uses the ChiaServer's own port; confirm.
            self._server_sockets = [
                await start_server(self._server, self._on_connect_callback)
                for _ in self._server_listen_ports
            ]

            try:
                asyncio.get_running_loop().add_signal_handler(
                    signal.SIGINT, self.stop)
                asyncio.get_running_loop().add_signal_handler(
                    signal.SIGTERM, self.stop)
            except NotImplementedError:
                # e.g. event loops without add_signal_handler support.
                self._log.info("signal handlers unsupported")

            # Block until stop() has closed every listener.
            for sock in self._server_sockets:
                await sock.wait_closed()

            await self._server.await_closed()
            if self._await_closed_callback:
                await self._await_closed_callback()

        self._task = asyncio.ensure_future(_run())

    async def run(self):
        """Start the service, wait until it has fully closed, and return 0."""
        self.start()
        await self.wait_closed()
        self._log.info("Closed all node servers.")
        return 0

    def stop(self):
        """Close listeners, cancel background tasks and close connections (once)."""
        if not self._is_stopping:
            self._is_stopping = True
            for sock in self._server_sockets:
                sock.close()
            for task in self._reconnect_tasks:
                task.cancel()
            self._server.close_all()
            self._api._shut_down = True
            if self._introducer_poll_task:
                self._introducer_poll_task.cancel()
            if self._stop_callback:
                self._stop_callback()

    async def wait_closed(self):
        """Await the run loop and then the RPC task, if one was started."""
        await self._task
        if self._rpc_task:
            await self._rpc_task
            self._log.info("Closed RPC server.")
        self._log.info("%s fully closed", self._node_type)
コード例 #16
0
class Service:
    """Host a node and its peer API behind a ChiaServer, with optional RPC.

    ``__init__`` does the following:

    - reads the shared config and the service's own section;
    - sets the process title and initialises logging;
    - builds the ``ChiaServer`` with the private and chia SSL CA paths.

    ``start()`` does the following:

    - installs signal handlers and starts the node;
    - remaps UPnP ports and starts the server;
    - launches reconnect tasks and, if configured, the RPC server.

    ``stop()`` requests shutdown, and ``wait_closed()`` waits for it to
    finish.
    """

    def __init__(
        self,
        root_path,
        node: Any,
        peer_api: Any,
        node_type: NodeType,
        advertised_port: int,
        service_name: str,
        network_id: bytes32,
        upnp_ports: Optional[List[int]] = None,
        server_listen_ports: Optional[List[int]] = None,
        connect_peers: Optional[List[PeerInfo]] = None,
        auth_connect_peers: bool = True,
        on_connect_callback: Optional[Callable] = None,
        rpc_info: Optional[Tuple[type, int]] = None,
        parse_cli_args=True,
        connect_to_daemon=True,
    ):
        """Configure the service; nothing runs until ``start()``.

        Args:
            root_path: Directory containing ``config.yaml``.
            node: Node object. ``_start``/``_close``/``_await_closed`` are
                called on it, and ``set_server`` too if present.
            peer_api: Peer API object handed to the ChiaServer.
            node_type: NodeType given to the ChiaServer.
            advertised_port: Port the ChiaServer is created on; also logged.
            service_name: Config section, logger name, process-title suffix.
            network_id: Network id passed to the ChiaServer and logged (as
                hex) on start. It is now required. The previous default was
                the ``bytes32`` type itself, which crashed in ``start()``.
            upnp_ports: Ports passed to ``upnp_remap_port`` on start.
            server_listen_ports: Stored; not otherwise read by this class.
            connect_peers: Peers passed to ``start_reconnect_task``.
            auth_connect_peers: Forwarded to ``start_reconnect_task``.
            on_connect_callback: Forwarded to ``ChiaServer.start_server``.
            rpc_info: Optional ``(rpc_api_class, rpc_port)``.
            parse_cli_args: If true, read the service section through
                ``load_config_cli`` instead of ``load_config``.
            connect_to_daemon: Forwarded to ``start_rpc_server``.

        For the list parameters, ``None`` (the default) means an empty list.
        This replaces the shared mutable ``[]`` defaults.
        """
        self.root_path = root_path
        self.config = load_config(root_path, "config.yaml")
        ping_interval = self.config.get("ping_interval")
        self.self_hostname = self.config.get("self_hostname")
        self.daemon_port = self.config.get("daemon_port")
        assert ping_interval is not None
        self._connect_to_daemon = connect_to_daemon
        self._node_type = node_type
        self._service_name = service_name
        self._rpc_task = None
        # Set by stop() when an RPC server is running. Initialised here so
        # wait_closed() does not raise AttributeError if start() never ran.
        self._rpc_close_task = None
        self._network_id: bytes32 = network_id

        proctitle_name = f"chia_{service_name}"
        setproctitle(proctitle_name)
        self._log = logging.getLogger(service_name)

        if parse_cli_args:
            service_config = load_config_cli(root_path, "config.yaml",
                                             service_name)
        else:
            service_config = load_config(root_path, "config.yaml",
                                         service_name)
        initialize_logging(service_name, service_config["logging"], root_path)

        self._rpc_info = rpc_info
        private_ca_crt, private_ca_key = private_ssl_ca_paths(
            root_path, self.config)
        chia_ca_crt, chia_ca_key = chia_ssl_ca_paths(root_path, self.config)
        self._server = ChiaServer(
            advertised_port,
            node,
            peer_api,
            node_type,
            ping_interval,
            network_id,
            root_path,
            service_config,
            (private_ca_crt, private_ca_key),
            (chia_ca_crt, chia_ca_key),
            name=f"{service_name}_server",
        )
        set_server = getattr(node, "set_server", None)
        if set_server:
            set_server(self._server)
        else:
            self._log.warning(f"No set_server method for {service_name}")

        self._connect_peers = connect_peers if connect_peers is not None else []
        self._auth_connect_peers = auth_connect_peers
        self._upnp_ports = upnp_ports if upnp_ports is not None else []
        self._server_listen_ports = (server_listen_ports
                                     if server_listen_ports is not None else [])

        self._api = peer_api
        self._node = node
        self._did_start = False
        self._is_stopping = asyncio.Event()
        self._stopped_by_rpc = False

        self._on_connect_callback = on_connect_callback
        self._advertised_port = advertised_port
        self._reconnect_tasks: List[asyncio.Task] = []

    async def start(self, **kwargs):
        """Start the node, server, reconnect tasks and RPC server (once).

        ``kwargs`` are forwarded to ``node._start``.
        """
        # we include `kwargs` as a hack for the wallet, which for some
        # reason allows parameters to `_start`. This is serious BRAIN DAMAGE,
        # and should be fixed at some point.
        # TODO: move those parameters to `__init__`
        if self._did_start:
            return
        self._did_start = True

        self._enable_signals()

        await self._node._start(**kwargs)

        for port in self._upnp_ports:
            upnp_remap_port(port)

        await self._server.start_server(self._on_connect_callback)

        self._reconnect_tasks = [
            start_reconnect_task(self._server, peer, self._log,
                                 self._auth_connect_peers)
            for peer in self._connect_peers
        ]
        self._log.info(
            f"Started {self._service_name} service on network_id: {self._network_id.hex()}"
        )

        self._rpc_close_task = None
        if self._rpc_info:
            rpc_api, rpc_port = self._rpc_info
            # The task's result is awaited and called in stop() to close the
            # RPC server.
            self._rpc_task = asyncio.create_task(
                start_rpc_server(
                    rpc_api(self._node),
                    self.self_hostname,
                    self.daemon_port,
                    rpc_port,
                    self.stop,
                    self.root_path,
                    self.config,
                    self._connect_to_daemon,
                ))

    async def run(self):
        """Start the service and wait until it has fully closed."""
        await self.start()
        await self.wait_closed()

    def _enable_signals(self):
        """Route SIGINT/SIGTERM (and SIGBREAK on Windows) to ``stop()``."""
        signal.signal(signal.SIGINT, self._accept_signal)
        signal.signal(signal.SIGTERM, self._accept_signal)
        if platform == "win32" or platform == "cygwin":
            # pylint: disable=E1101
            signal.signal(signal.SIGBREAK, self._accept_signal)  # type: ignore

    def _accept_signal(self, signal_number: int, stack_frame):
        """Signal handler: log the signal and request shutdown."""
        self._log.info(f"got signal {signal_number}")
        self.stop()

    def stop(self):
        """Request shutdown once: cancel reconnects, close server, node and RPC."""
        if not self._is_stopping.is_set():
            self._is_stopping.set()
            self._log.info("Cancelling reconnect task")
            for task in self._reconnect_tasks:
                task.cancel()
            self._log.info("Closing connections")
            self._server.close_all()
            self._node._close()
            self._node._shut_down = True

            self._log.info("Calling service stop callback")

            if self._rpc_task is not None:
                self._log.info("Closing RPC server")

                async def close_rpc_server():
                    # The RPC start task resolves to a coroutine function
                    # that shuts the RPC server down.
                    await (await self._rpc_task)()

                self._rpc_close_task = asyncio.create_task(close_rpc_server())

    async def wait_closed(self):
        """Wait for stop() and then for the server, RPC server and node to close."""
        await self._is_stopping.wait()

        self._log.info("Waiting for socket to be closed (if opened)")

        self._log.info("Waiting for ChiaServer to be closed")
        await self._server.await_closed()

        if self._rpc_close_task:
            self._log.info("Waiting for RPC server")
            await self._rpc_close_task
            self._log.info("Closed RPC server")

        self._log.info("Waiting for service _await_closed callback")
        await self._node._await_closed()
        self._log.info(
            f"Service {self._service_name} at port {self._advertised_port} fully closed"
        )
コード例 #17
0
class Service:
    """Host a node ``api`` behind a ``ChiaServer`` and manage its lifecycle.

    ``start`` schedules startup: an optional RPC server, one reconnect task per
    entry in ``connect_peers``, one ``start_server`` call per entry in
    ``server_listen_ports``, then process signal handlers. ``stop`` tears those
    down without waiting, and ``wait_closed`` waits for everything to finish.
    ``run`` combines all three.
    """

    def __init__(
        self,
        root_path,
        api: Any,
        node_type: NodeType,
        advertised_port: int,
        service_name: str,
        server_listen_ports: Optional[List[int]] = None,
        connect_peers: Optional[List[PeerInfo]] = None,
        auth_connect_peers: bool = True,
        on_connect_callback: Optional[OnConnectFunc] = None,
        rpc_info: Optional[Tuple[type, int]] = None,
        start_callback: Optional[Callable] = None,
        stop_callback: Optional[Callable] = None,
        await_closed_callback: Optional[Callable] = None,
        parse_cli_args=True,
    ):
        """Load configuration, set up logging and build the ChiaServer.

        Args:
            root_path: Root directory passed to ``load_config``.
            api: Node API served by the ChiaServer. If it has ``set_server``
                and/or ``_set_server`` attributes, each is called with the
                new server.
            node_type: Node type passed to ``ChiaServer``.
            advertised_port: Port passed to ``ChiaServer``; also used in logs.
            service_name: Config section, logger name and process-title suffix.
            server_listen_ports: ``start_server`` is called once per entry.
                Defaults to no entries.
            connect_peers: Peers given a reconnect task each. Defaults to none.
            auth_connect_peers: Forwarded to ``start_reconnect_task``.
            on_connect_callback: Forwarded to ``start_server``.
            rpc_info: ``(rpc_api_class, rpc_port)``. When given, an RPC server
                wrapping ``rpc_api_class(api)`` is started.
            start_callback: Awaited first thing in the startup task.
            stop_callback: Called synchronously from ``stop``.
            await_closed_callback: Awaited at the end of ``wait_closed``.
            parse_cli_args: Read the service section with ``load_config_cli``
                instead of ``load_config``.

        Raises:
            AssertionError: if ``ping_interval`` or ``network_id`` is missing
                from config.yaml.
        """
        net_config = load_config(root_path, "config.yaml")
        ping_interval = net_config.get("ping_interval")
        network_id = net_config.get("network_id")
        self.self_hostname = net_config.get("self_hostname")
        self.daemon_port = net_config.get("daemon_port")
        assert ping_interval is not None
        assert network_id is not None

        self._node_type = node_type
        self._service_name = service_name

        proctitle_name = f"chia_{service_name}"
        setproctitle(proctitle_name)
        self._log = logging.getLogger(service_name)
        if parse_cli_args:
            config = load_config_cli(root_path, "config.yaml", service_name)
        else:
            config = load_config(root_path, "config.yaml", service_name)
        initialize_logging(service_name, config["logging"], root_path)

        self._rpc_info = rpc_info

        self._server = ChiaServer(
            advertised_port,
            api,
            node_type,
            ping_interval,
            network_id,
            root_path,
            config,
            name=f"{service_name}_server",
        )
        # The API may expose either spelling of the setter; call whichever exist.
        for setter_name in ("set_server", "_set_server"):
            setter = getattr(api, setter_name, None)
            if setter:
                setter(self._server)

        # Give each instance its own list instead of a shared mutable default.
        self._connect_peers = connect_peers if connect_peers is not None else []
        self._auth_connect_peers = auth_connect_peers
        self._server_listen_ports = (
            server_listen_ports if server_listen_ports is not None else []
        )

        self._api = api
        self._task = None
        self._is_stopping = False
        self._stopped_by_rpc = False

        self._on_connect_callback = on_connect_callback
        self._start_callback = start_callback
        self._stop_callback = stop_callback
        self._await_closed_callback = await_closed_callback
        self._advertised_port = advertised_port
        self._server_sockets: List = []
        # Initialised here, not only inside the startup task, so that stop()
        # and wait_closed() work even if that task has not run yet.
        self._reconnect_tasks: List = []
        self._rpc_task = None
        self._rpc_close_task = None

    def start(self):
        """Schedule the service's startup as an asyncio task (idempotent).

        Must be called with a running event loop (uses ``asyncio.create_task``).
        """
        if self._task is not None:
            return

        async def _run():
            if self._start_callback:
                await self._start_callback()

            # If stop() already ran (before or during the start callback), do
            # not bring up RPC, reconnects or listeners: nothing would tear
            # them down again.
            if self._is_stopping:
                return

            if self._rpc_info:
                rpc_api, rpc_port = self._rpc_info

                self._rpc_task = asyncio.create_task(
                    start_rpc_server(
                        rpc_api(self._api),
                        self.self_hostname,
                        self.daemon_port,
                        rpc_port,
                        self.stop,
                    ))

            self._reconnect_tasks = [
                start_reconnect_task(self._server, peer, self._log,
                                     self._auth_connect_peers)
                for peer in self._connect_peers
            ]
            self._server_sockets = [
                await start_server(self._server, self._on_connect_callback)
                for _ in self._server_listen_ports
            ]

            signal.signal(signal.SIGINT, global_signal_handler)
            signal.signal(signal.SIGTERM, global_signal_handler)
            if platform == "win32" or platform == "cygwin":
                # pylint: disable=E1101
                signal.signal(signal.SIGBREAK,
                              global_signal_handler)  # type: ignore

        self._task = asyncio.create_task(_run())

    async def run(self):
        """Start, wait until asked to stop, shut down, and return 0.

        After startup, polls once per second until either the module-level
        ``stopped_by_signal`` flag is truthy or ``stop()`` has been called.
        """
        self.start()
        await self._task
        while not stopped_by_signal and not self._is_stopping:
            await asyncio.sleep(1)

        self.stop()
        await self.wait_closed()
        return 0

    def stop(self):
        """Begin shutting the service down; repeated calls are no-ops.

        Closes listening sockets, cancels reconnect tasks, closes all
        connections, marks the API ``_shut_down``, runs ``stop_callback`` and,
        if an RPC server task exists, schedules ``_rpc_close_task`` to close
        it. Does not wait for any of this; use ``wait_closed``.
        """
        if self._is_stopping:
            return
        self._is_stopping = True

        self._log.info("Closing server sockets")
        for server_socket in self._server_sockets:
            server_socket.close()
        self._log.info("Cancelling reconnect task")
        for reconnect_task in self._reconnect_tasks:
            reconnect_task.cancel()
        self._log.info("Closing connections")
        self._server.close_all()
        self._api._shut_down = True

        self._log.info("Calling service stop callback")
        if self._stop_callback:
            self._stop_callback()

        if self._rpc_task:
            self._log.info("Closing RPC server")

            async def close_rpc_server():
                # The RPC task resolves to a callable; call it and await the
                # result (presumably the server's close coroutine).
                await (await self._rpc_task)()

            self._rpc_close_task = asyncio.create_task(close_rpc_server())

    async def wait_closed(self):
        """Wait for sockets, the ChiaServer, the RPC server and the service
        ``await_closed_callback`` to finish closing, in that order."""
        self._log.info("Waiting for socket to be closed (if opened)")
        for server_socket in self._server_sockets:
            await server_socket.wait_closed()

        self._log.info("Waiting for ChiaServer to be closed")
        await self._server.await_closed()

        if self._rpc_close_task:
            self._log.info("Waiting for RPC server")
            await self._rpc_close_task
            self._log.info("Closed RPC server")

        if self._await_closed_callback:
            self._log.info("Waiting for service _await_closed callback")
            await self._await_closed_callback()
        self._log.info(
            f"Service {self._service_name} at port {self._advertised_port} fully closed"
        )