Example #1
def test_compute_expert_scores():
    try:
        dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
        moe = hivemind.client.moe.RemoteMixtureOfExperts(dht=dht,
                                                         in_features=1024,
                                                         grid_size=[40],
                                                         k_best=4,
                                                         k_min=1,
                                                         timeout_after_k_min=1,
                                                         uid_prefix='expert')
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [
            [
                hivemind.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}')
                for e in range(len(ii[b]))
            ] for b in range(len(ii))
        ]  # note: these experts do not exist on the server; we use them only to test moe.compute_expert_scores
        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, \
            "compute_expert_scores didn't backprop"

        for b in range(len(ii)):
            for e in range(len(ii[b])):
                assert torch.allclose(
                    logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]
                ), "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
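For context, in ordinary use the mixture module is simply called on a batch of inputs rather than through compute_expert_scores. A minimal, hypothetical sketch (it assumes a running hivemind server that actually hosts experts under the 'expert' prefix, which the test above deliberately does not start):

x = torch.randn(2, 1024)   # a batch matching in_features=1024
output = moe(x)            # hypothetical call: looks up the k_best experts via the DHT and combines their outputs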
Example #2
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener,
                           args=(peer_port, peer_id, peer_started),
                           daemon=True)
    peer_proc.start(), peer_started.wait()
    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        protocol = loop.run_until_complete(
            DHTProtocol.create(DHTID.generate(),
                               bucket_size=20,
                               depth_modulo=5,
                               wait_timeout=5,
                               num_replicas=3,
                               listen=False))

        key, value, expiration = (DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}],
                                  get_dht_time() + 1e3)

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
        assert all(
            loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer_port}', [key],
                                    [hivemind.MSGPackSerializer.dumps(value)],
                                    expiration))), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(
            protocol.call_ping(
                f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()
    peer_proc.terminate()
Example #3
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener,
                           args=(peer_port, peer_id, peer_started),
                           daemon=True)
    peer_proc.start(), peer_started.wait()

    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(
        DHTProtocol.create(DHTID.generate(),
                           bucket_size=20,
                           depth_modulo=5,
                           wait_timeout=5,
                           num_replicas=3,
                           listen=False))

    key, value, expiration = (DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}],
                              get_dht_time() + 1e3)

    empty_item, nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(
        loop.run_until_complete(
            protocol.call_store(f'{LOCALHOST}:{peer_port}', [key],
                                [hivemind.MSGPackSerializer.dumps(value)],
                                expiration))), "peer rejected store"

    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert loop.run_until_complete(
        protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
    assert loop.run_until_complete(
        protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
    peer_proc.terminate()
Example #4
def benchmark_throughput(num_experts=16,
                         num_handlers=None,
                         num_clients=128,
                         num_batches_per_client=16,
                         expert_cls='ffn',
                         hid_dim=1024,
                         batch_size=2048,
                         max_batch_size=None,
                         backprop=True,
                         device=None,
                         port=None):
    assert not hasattr(torch.cuda, 'is_initialized') or not torch.cuda.is_initialized() \
           or torch.device(device) == torch.device('cpu')
    assert expert_cls in layers.name_to_block
    port = port or find_open_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
    can_start = mp.Event()
    timestamps = dict(started=time.perf_counter())

    try:
        # start clients and await server
        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
        clients = [
            mp.Process(target=client_process,
                       name=f'client_process-{i}',
                       args=(can_start, benchmarking_failed, port, num_experts,
                             batch_size, hid_dim, num_batches_per_client,
                             backprop)) for i in range(num_clients)
        ]

        for client in clients:
            client.daemon = True
            client.start()

        timestamps['launched_clients'] = timestamps[
            'began_launching_server'] = time.perf_counter()

        # start server
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        experts = {}
        for i in range(num_experts):
            expert = torch.jit.script(
                layers.name_to_block[expert_cls](hid_dim))
            experts[f'expert{i}'] = hivemind.ExpertBackend(
                name=f'expert{i}',
                expert=expert,
                opt=torch.optim.Adam(expert.parameters()),
                args_schema=(hivemind.BatchTensorProto(hid_dim), ),
                outputs_schema=hivemind.BatchTensorProto(hid_dim),
                max_batch_size=max_batch_size,
            )
        timestamps['created_experts'] = time.perf_counter()
        server = hivemind.Server(None,
                                 experts,
                                 port=port,
                                 conn_handler_processes=num_handlers,
                                 device=device)
        server.start()
        server.ready.wait()
        timestamps['server_ready'] = time.perf_counter()
        can_start.set()

        for client in clients:
            client.join()
        timestamps['clients_finished'] = time.perf_counter()
    except BaseException as e:
        benchmarking_failed.set()
        raise e
    finally:
        for client in clients:
            if client.is_alive():
                client.terminate()
        server.shutdown()
        timestamps['server_shutdown_finished'] = time.perf_counter()
        server.join()

    sys.stdout.flush()
    sys.stderr.flush()
    time_between = lambda key1, key2: \
        abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
    total_examples = batch_size * num_clients * num_batches_per_client

    print('\n' * 3)
    print("Benchmark finished, status:".format(
        ["Success", "Failure"][benchmarking_failed.is_set()]))
    print(
        "Server parameters: num_experts={}, num_handlers={}, max_batch_size={}, expert_cls={}, hid_dim={}, device={}"
        .format(num_experts, num_handlers, max_batch_size, expert_cls, hid_dim,
                device))
    print(
        "Client parameters: num_clients={}, num_batches_per_client={}, batch_size={}, backprop={}"
        .format(num_clients, num_batches_per_client, batch_size, backprop))

    startup_time = time_between('began_launching_server', 'server_ready')
    experts_time = time_between('began_launching_server', 'created_experts')
    networking_time = time_between('created_experts', 'server_ready')
    process_examples_time = time_between('server_ready', 'clients_finished')
    overall_time = time_between('started', 'server_shutdown_finished')

    stage = 'forward + backward' if backprop else 'forward'

    print("Results: ")
    print("\tServer startup took {} s. ({} s. experts + {} s. networking)".
          format(startup_time, experts_time, networking_time, '.3f'))
    print("\tProcessed {} examples in {}".format(
        total_examples, time_betweenprocess_examples_time, '.3f'))
    print("\tThroughput for {} passes: {} samples / s.".format(
        stage, total_examples / process_examples_time, '.3f'))
    print("\tBenchmarking took {} s.".format(overall_time, '.3f'))

    if benchmarking_failed.is_set():
        print(
            "Note: benchmark code failed, timing/memory results only indicate time till failure!"
        )
    print_device_info(device)
    print(flush=True)

    assert not benchmarking_failed.is_set()
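For reference, a minimal sketch of invoking the benchmark above directly; the values below are illustrative small-scale settings, not tuned defaults, and every keyword comes from benchmark_throughput's own signature:

if __name__ == '__main__':
    # illustrative small-scale run; see the signature above for the full list of parameters
    benchmark_throughput(num_experts=4, num_clients=16, num_batches_per_client=8,
                         expert_cls='ffn', hid_dim=1024, batch_size=512, backprop=True)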
Example #5
def test_dht_protocol():
    # create the first peer
    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer1_proc = mp.Process(target=run_protocol_listener,
                            args=(peer1_port, peer1_id, peer1_started),
                            daemon=True)
    peer1_proc.start(), peer1_started.wait()

    # create another peer that connects to the first peer
    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer2_proc = mp.Process(target=run_protocol_listener,
                            args=(peer2_port, peer2_id, peer2_started),
                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'},
                            daemon=True)
    peer2_proc.start(), peer2_started.wait()

    loop = asyncio.get_event_loop()
    # note: order matters, this test assumes that first run uses listen=False
    for listen in [False, True]:
        protocol = loop.run_until_complete(
            DHTProtocol.create(DHTID.generate(),
                               bucket_size=20,
                               depth_modulo=5,
                               wait_timeout=5,
                               num_replicas=3,
                               listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

        key, value, expiration = (DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}],
                                  get_dht_time() + 1e3)
        store_ok = loop.run_until_complete(
            protocol.call_store(f'{LOCALHOST}:{peer1_port}', [key],
                                [hivemind.MSGPackSerializer.dumps(value)],
                                expiration))
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes,
         recv_expiration), nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer2_port}',
                               [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = hivemind.find_open_port()
        assert loop.run_until_complete(
            protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
        assert loop.run_until_complete(
            protocol.call_store(
                f'{LOCALHOST}:{peer1_port}',
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value1)],
                expiration_time=[expiration],
                subkeys=[subkey1]))
        assert loop.run_until_complete(
            protocol.call_store(
                f'{LOCALHOST}:{peer1_port}',
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value2)],
                expiration_time=[expiration + 5],
                subkeys=[subkey2]))
        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}',
                               [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1),
                                           expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2),
                                           expiration + 5)

        assert LOCALHOST in loop.run_until_complete(
            protocol.get_outgoing_request_endpoint(
                f'{LOCALHOST}:{peer1_port}'))

        if listen:
            loop.run_until_complete(protocol.shutdown())

    peer1_proc.terminate()
    peer2_proc.terminate()
Example #6
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        # note: order matters, this test assumes that first run uses listen=False
        for listen in [False, True]:
            protocol = loop.run_until_complete(
                DHTProtocol.create(DHTID.generate(),
                                   bucket_size=20,
                                   depth_modulo=5,
                                   wait_timeout=5,
                                   num_replicas=3,
                                   listen=listen))
            print(f"Self id={protocol.node_id}", flush=True)

            assert loop.run_until_complete(
                protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

            key, value, expiration = (DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}],
                                      get_dht_time() + 1e3)
            store_ok = loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer1_port}', [key],
                                    [hivemind.MSGPackSerializer.dumps(value)],
                                    expiration))
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

            assert recv_value == value and recv_expiration == expiration, \
                f"call_find_value expected {value} (expires by {expiration}) " \
                f"but got {recv_value} (expires by {recv_expiration})"

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer2_port}',
                                   [dummy_key]))[dummy_key]
            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

            # cause a non-response by querying a nonexistent peer
            dummy_port = hivemind.find_open_port()
            assert loop.run_until_complete(
                protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

            if listen:
                loop.run_until_complete(protocol.shutdown())
            print("DHTProtocol test finished successfully!")
            test_success.set()
Example #7
    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
               **kwargs) -> Server:
        """
        Instantiate a server with several identical experts. See argparse comments below for details
        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
        :param num_experts: run this many identical experts
        :param expert_pattern: a string pattern or a list of expert uids, e.g. myprefix.[0:32].[0:256]\
           means "sample random expert uids between myprefix.0.0 and myprefix.31.255"
        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
        :param hidden_dim: main dimension for expert_cls
        :param num_handlers: server will use this many parallel processes to handle incoming requests
        :param max_batch_size: total num examples in the same batch will not exceed this value
        :param device: all experts will use this device in torch notation; default: cuda if available else cpu

        :param optim_cls: uses this optimizer to train all experts
        :param scheduler: if not `none`, the name of the expert LR scheduler
        :param num_warmup_steps: the number of warmup steps for LR schedule
        :param num_total_steps: the total number of steps for LR schedule
        :param clip_grad_norm: maximum gradient norm used for clipping

        :param no_dht: if specified, the server will not be attached to a dht
        :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

        :param dht_port:  DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.

        :param checkpoint_dir: directory to save and load expert checkpoints

        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
            for each BatchTensorProto in ExpertBackend for the respective experts.

        :param start: if True, starts server right away and returns when server is ready for requests
        :param stats_report_interval: interval between two reports of batch processing performance statistics
        """
        if len(kwargs) != 0:
            logger.info("Ignored kwargs:", kwargs)
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
                (num_experts is not None and expert_uids is None)), \
            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
                               (child / 'checkpoint_last.pt').exists()]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
            else:
                expert_uids = []

            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        sample_input = name_to_input[expert_cls](4, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                         args_schema=args_schema,
                                                         outputs_schema=hivemind.BatchTensorDescriptor(
                                                             hidden_dim, compression=compression),
                                                         optimizer=optim_cls(expert.parameters()),
                                                         scheduler=scheduler,
                                                         num_warmup_steps=num_warmup_steps,
                                                         num_total_steps=num_total_steps,
                                                         clip_grad_norm=clip_grad_norm,
                                                         max_batch_size=max_batch_size)

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)
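A minimal usage sketch for the factory above; the values are illustrative and the pattern follows the docstring's own example:

# illustrative call: two experts with uids sampled from the given pattern, no LR scheduler
server = Server.create(num_experts=2, expert_pattern='myprefix.[0:32].[0:256]', expert_cls='ffn',
                       hidden_dim=1024, scheduler='none', start=True)
try:
    pass  # the server is now ready; its experts are announced via the DHT and can serve remote requests
finally:
    server.shutdown()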
Example #8
def make_dummy_server(host='0.0.0.0',
                      port=None,
                      num_experts=1,
                      expert_cls='ffn',
                      hidden_dim=1024,
                      num_handlers=None,
                      expert_prefix='expert',
                      expert_offset=0,
                      max_batch_size=16384,
                      device=None,
                      no_optimizer=False,
                      no_dht=False,
                      initial_peers=(),
                      dht_port=None,
                      root_port=None,
                      verbose=True,
                      start=False,
                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER,
                      **kwargs) -> hivemind.Server:
    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize dht
    dht = None
    if not no_dht:
        if not len(initial_peers):
            print(
                "No initial peers provided. Starting additional dht as an initial peer."
            )
            dht_root = hivemind.DHTNode(*initial_peers,
                                        port=root_port or hivemind.find_open_port(),
                                        start=True)
            print(f"Initializing DHT with port {dht_root.port}")
            initial_peers = (('localhost', dht_root.port), )
        else:
            print("Bootstrapping dht with peers:", initial_peers)
            if root_port is not None:
                print(
                    f"Warning: root_port={root_port} will not be used since we already have peers."
                )

        dht = hivemind.DHTNode(*initial_peers,
                               port=dht_port or hivemind.find_open_port(),
                               start=True)
        if verbose:
            print(f"Running dht node on port {dht.port}")

    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        opt = (torch.optim.SGD(expert.parameters(), lr=0.0) if no_optimizer
               else torch.optim.Adam(expert.parameters()))
        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
        experts[expert_uid] = hivemind.ExpertBackend(
            name=expert_uid,
            expert=expert,
            opt=opt,
            args_schema=(hivemind.BatchTensorProto(hidden_dim), ),
            outputs_schema=hivemind.BatchTensorProto(hidden_dim),
            max_batch_size=max_batch_size,
        )
    # actually start server
    server = hivemind.Server(dht,
                             experts,
                             addr=host,
                             port=port or hivemind.find_open_port(),
                             conn_handler_processes=num_handlers,
                             device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(
                f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}"
            )
    return server
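A minimal usage sketch for the helper above (illustrative values; no_dht=True avoids bootstrapping a DHT node for a single-machine test):

server = make_dummy_server(num_experts=2, expert_cls='ffn', hidden_dim=1024,
                           no_dht=True, start=True, verbose=True)
try:
    pass  # interact with the spawned experts, e.g. through hivemind.RemoteExpert
finally:
    server.shutdown()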
Example #9
    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
        """
        Instantiate a server with several identical experts. See argparse comments below for details
        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
        :param num_experts: run this many identical experts
        :param expert_pattern: a string pattern or a list of expert uids, e.g. myprefix.[0:32].[0:256]\
           means "sample random expert uids between myprefix.0.0 and myprefix.31.255"
        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
        :param hidden_dim: main dimension for expert_cls
        :param num_handlers: server will use this many parallel processes to handle incoming requests
        :param max_batch_size: total num examples in the same batch will not exceed this value
        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
        :param optim_cls: uses this optimizer to train all experts
        :param no_dht: if specified, the server will not be attached to a dht
        :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

        :param dht_port:  DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.

        :param checkpoint_dir: directory to save expert checkpoints
        :param load_experts: whether to load expert checkpoints from checkpoint_dir

        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
            for each BatchTensorProto in ExpertBackend for the respective experts.

        :param start: if True, starts server right away and returns when server is ready for requests
        """
        if len(kwargs) != 0:
            logger.info("Ignored kwargs:", kwargs)
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

        if load_experts:
            assert dir_is_correct(checkpoint_dir)
            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
            if expert_uids:
                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
            else:
                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")

        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"

        # get expert uids if not loaded previously
        if expert_uids is None:
            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
            logger.info(f"Generating expert uids from pattern {expert_pattern}")
            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        sample_input = name_to_input[expert_cls](4, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                         args_schema=args_schema,
                                                         outputs_schema=hivemind.BatchTensorDescriptor(
                                                             hidden_dim, compression=compression),
                                                         opt=optim_cls(expert.parameters()),
                                                         max_batch_size=max_batch_size)

        if load_experts:
            load_weights(experts, checkpoint_dir)

        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                        start=start)
        return server
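A minimal usage sketch for this variant as well; here the experts are given explicit uids instead of a generation pattern (illustrative values):

# illustrative call: two experts with fixed uids, default optimizer and compression
server = Server.create(expert_uids=['myprefix.0.0', 'myprefix.0.1'], expert_cls='ffn',
                       hidden_dim=1024, start=True)
server.shutdown()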