Example 1
def test_client_anomaly_detection():
    HID_DIM = 16

    experts = {}
    for i in range(4):
        expert = layers.name_to_block['ffn'](HID_DIM)
        experts[f'expert.{i}'] = hivemind.ExpertBackend(
            name=f'expert.{i}',
            expert=expert,
            optimizer=torch.optim.Adam(expert.parameters()),
            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM), ),
            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
            max_batch_size=16,
        )

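    # corrupt one expert so that its outputs contain NaN values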
    experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')

    dht = hivemind.DHT(start=True)
    server = hivemind.Server(dht, experts, num_connection_handlers=1)
    server.start()
    try:
        server.ready.wait()

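        # grid_size=(3,) covers only expert.0 .. expert.2, so the corrupted expert.3 is never queried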
        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16,
                                               grid_size=(3, ),
                                               dht=dht,
                                               k_best=3,
                                               uid_prefix='expert.',
                                               detect_anomalies=True)

        input = torch.randn(1, 16)
        input[0, 0] = float('nan')

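        # with detect_anomalies=True, a NaN input must be rejected with a ValueError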
        with pytest.raises(ValueError):
            dmoe(input)

        input[0, 0] = 0
        output = dmoe(input)

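        # a non-finite loss produces non-finite gradients, which must also be rejected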
        inf_loss = float('inf') * output.sum()
        with pytest.raises(ValueError):
            inf_loss.backward()

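        # grid_size=(4,) now includes the corrupted expert.3: its NaN responses should be
        # filtered out by anomaly detection, so the averaged output stays finite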
        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16,
                                               grid_size=(4, ),
                                               dht=dht,
                                               k_best=4,
                                               uid_prefix='expert.',
                                               detect_anomalies=True)
        output = dmoe(input)
        assert output.isfinite().all()

    finally:
        server.shutdown()
Example 2
def benchmark_throughput(num_experts=16,
                         num_handlers=None,
                         num_clients=128,
                         num_batches_per_client=16,
                         expert_cls='ffn',
                         hid_dim=1024,
                         batch_size=2048,
                         max_batch_size=None,
                         backprop=True,
                         device=None,
                         port=None):
    assert not hasattr(torch.cuda, 'is_initialized') or not torch.cuda.is_initialized() \
           or torch.device(device) == torch.device('cpu')
    assert expert_cls in layers.name_to_block
    port = port or find_open_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
    can_start = mp.Event()
    timestamps = dict(started=time.perf_counter())

    try:
        # start clients and await server
        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
        clients = [
            mp.Process(target=client_process,
                       name=f'client_process-{i}',
                       args=(can_start, benchmarking_failed, port, num_experts,
                             batch_size, hid_dim, num_batches_per_client,
                             backprop)) for i in range(num_clients)
        ]

        for client in clients:
            client.daemon = True
            client.start()

        timestamps['launched_clients'] = timestamps[
            'began_launching_server'] = time.perf_counter()

        # start server
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        experts = {}
        for i in range(num_experts):
            expert = torch.jit.script(
                layers.name_to_block[expert_cls](hid_dim))
            experts[f'expert{i}'] = hivemind.ExpertBackend(
                name=f'expert{i}',
                expert=expert,
                opt=torch.optim.Adam(expert.parameters()),
                args_schema=(hivemind.BatchTensorProto(hid_dim), ),
                outputs_schema=hivemind.BatchTensorProto(hid_dim),
                max_batch_size=max_batch_size,
            )
        timestamps['created_experts'] = time.perf_counter()
        server = hivemind.Server(None,
                                 experts,
                                 port=port,
                                 conn_handler_processes=num_handlers,
                                 device=device)
        server.start()
        server.ready.wait()
        timestamps['server_ready'] = time.perf_counter()
        can_start.set()

        for client in clients:
            client.join()
        timestamps['clients_finished'] = time.perf_counter()
    except BaseException as e:
        benchmarking_failed.set()
        raise e
    finally:
        for client in clients:
            if client.is_alive():
                client.terminate()
        server.shutdown()
        timestamps['server_shutdown_finished'] = time.perf_counter()
        server.join()

    sys.stdout.flush()
    sys.stderr.flush()
    time_between = lambda key1, key2: \
        abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
    total_examples = batch_size * num_clients * num_batches_per_client

    print('\n' * 3)
    print("Benchmark finished, status:".format(
        ["Success", "Failure"][benchmarking_failed.is_set()]))
    print(
        "Server parameters: num_experts={}, num_handlers={}, max_batch_size={}, expert_cls={}, hid_dim={}, device={}"
        .format(num_experts, num_handlers, max_batch_size, expert_cls, hid_dim,
                device))
    print(
        "Client parameters: num_clients={}, num_batches_per_client={}, batch_size={}, backprop={}"
        .format(num_clients, num_batches_per_client, batch_size, backprop))

    startup_time = time_between('began_launching_server', 'server_ready')
    experts_time = time_between('began_launching_server', 'created_experts')
    networking_time = time_between('created_experts', 'server_ready')
    process_examples_time = time_between('server_ready', 'clients_finished')
    overall_time = time_between('started', 'server_shutdown_finished')

    stage = 'forward + backward' if backprop else 'forward'

    print("Results: ")
    print("\tServer startup took {} s. ({} s. experts + {} s. networking)".
          format(startup_time, experts_time, networking_time, '.3f'))
    print("\tProcessed {} examples in {}".format(
        total_examples, time_betweenprocess_examples_time, '.3f'))
    print("\tThroughput for {} passes: {} samples / s.".format(
        stage, total_examples / process_examples_time, '.3f'))
    print("\tBenchmarking took {} s.".format(overall_time, '.3f'))

    if benchmarking_failed.is_set():
        print(
            "Note: benchmark code failed, timing/memory results only indicate time till failure!"
        )
    print_device_info(device)
    print(flush=True)

    assert not benchmarking_failed.is_set()
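
A minimal driver sketch for the benchmark above (hypothetical, not part of the original file; the argument values are illustrative, and the snippet assumes the helpers referenced by benchmark_throughput, such as client_process and print_device_info, are defined in the same module):

if __name__ == '__main__':
    # run a small CPU-friendly configuration; larger values reproduce the full benchmark
    benchmark_throughput(num_experts=4,
                         num_clients=8,
                         num_batches_per_client=4,
                         hid_dim=64,
                         batch_size=256,
                         backprop=True)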
Example 3
def make_dummy_server(host='0.0.0.0',
                      port=None,
                      num_experts=1,
                      expert_cls='ffn',
                      hidden_dim=1024,
                      num_handlers=None,
                      expert_prefix='expert',
                      expert_offset=0,
                      max_batch_size=16384,
                      device=None,
                      no_optimizer=False,
                      no_dht=False,
                      initial_peers=(),
                      dht_port=None,
                      root_port=None,
                      verbose=True,
                      start=False,
                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER,
                      **kwargs) -> hivemind.Server:
    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize dht
    dht = None
    if not no_dht:
        if not len(initial_peers):
            print(
                "No initial peers provided. Starting additional dht as an initial peer."
            )
            dht_root = hivemind.DHTNode(*initial_peers,
                                        port=root_port
                                        or hivemind.find_open_port(),
                                        start=True)
            print(f"Initializing DHT with port {dht_root.port}")
            initial_peers = (('localhost', dht_root.port), )
        else:
            print("Bootstrapping dht with peers:", initial_peers)
            if root_port is not None:
                print(
                    f"Warning: root_port={root_port} will not be used since we already have peers."
                )

        dht = hivemind.DHTNode(*initial_peers,
                               port=dht_port or hivemind.find_open_port(),
                               start=True)
        if verbose:
            print(f"Running dht node on port {dht.port}")

    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        opt = torch.optim.SGD(expert.parameters(),
                              0.0) if no_optimizer else torch.optim.Adam(
                                  expert.parameters())
        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
        experts[expert_uid] = hivemind.ExpertBackend(
            name=expert_uid,
            expert=expert,
            opt=opt,
            args_schema=(hivemind.BatchTensorProto(hidden_dim), ),
            outputs_schema=hivemind.BatchTensorProto(hidden_dim),
            max_batch_size=max_batch_size,
        )
    # actually start server
    server = hivemind.Server(dht,
                             experts,
                             addr=host,
                             port=port or hivemind.find_open_port(),
                             conn_handler_processes=num_handlers,
                             device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(
                f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}"
            )
    return server
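
A minimal usage sketch (hypothetical driver code, not part of the original file; it assumes the hivemind version whose API is used above):

server = make_dummy_server(num_experts=2,
                           expert_cls='ffn',
                           hidden_dim=64,
                           no_dht=True,
                           start=True)
try:
    # interact with the running server here, e.g. via hivemind.RemoteExpert
    pass
finally:
    server.shutdown()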