Beispiel #1
0
def init_client(args, device, comm, process_id, size, model,
                train_data_num, local_data_num, train_data_local):
    """Create this worker's FedAvg trainer and run its client manager.

    The trainer receives ``process_id - 1`` as its client ID. Process 0 is
    presumably the server, so client processes start at rank 1 (TODO confirm
    against the launcher). After that, a ``FedAVGClientManager`` wrapping the
    trainer is constructed and its ``run()`` method is called. This function
    returns when ``run()`` returns.

    Args:
        args: experiment configuration, forwarded to the trainer and manager.
        device: forwarded unchanged to the trainer.
        comm: communicator forwarded to the client manager.
        process_id: this process's rank.
        size: total process count, forwarded to the client manager.
        model: model forwarded to the trainer.
        train_data_num: forwarded unchanged to the trainer.
        local_data_num: this client's local data count, forwarded to the
            trainer (presumably a sample count; TODO confirm).
        train_data_local: this client's local training data, forwarded to
            the trainer.
    """
    local_trainer = FedAVGTrainer(
        process_id - 1,  # client ID: rank shifted down by one
        train_data_local,
        local_data_num,
        train_data_num,
        device,
        model,
        args,
    )
    FedAVGClientManager(args, comm, process_id, size, local_trainer).run()
Beispiel #2
0
def init_client(args, device, comm, process_id, size, model, train_data_num,
                train_data_local_num_dict, train_data_local_dict):
    """Build the FedAvg trainer for this worker process and start its manager.

    The trainer is given ``process_id - 1`` as its client index, together with
    the full per-client data dictionaries. Process 0 is presumably the server,
    and the trainer presumably selects its own slice by index; both points are
    TODO confirm. The trainer is then handed to a ``FedAVGClientManager``
    whose ``run()`` method is called. This function returns when ``run()``
    returns.

    Args:
        args: experiment configuration, forwarded to the trainer and manager.
        device: forwarded unchanged to the trainer.
        comm: communicator forwarded to the client manager.
        process_id: this process's rank.
        size: total process count, forwarded to the client manager.
        model: model forwarded to the trainer.
        train_data_num: forwarded unchanged to the trainer.
        train_data_local_num_dict: forwarded whole to the trainer (presumably
            keyed by client index; verify against caller).
        train_data_local_dict: forwarded whole to the trainer (presumably
            keyed by client index; verify against caller).
    """
    fedavg_trainer = FedAVGTrainer(
        process_id - 1,  # client index: rank shifted down by one
        train_data_local_dict,
        train_data_local_num_dict,
        train_data_num,
        device,
        model,
        args,
    )
    FedAVGClientManager(args, fedavg_trainer, comm, process_id, size).run()
Beispiel #3
0
                                  4)

    # Load the federated dataset. load_data returns an 8-element sequence
    # (the unpacking below would raise otherwise), so unpack it once and use
    # the named fields from here on.
    dataset = load_data(args, args.dataset)
    (
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = dataset

    # Create the model.
    # Note: if the model is a DNN (e.g., ResNet), training will be very slow
    # here; in that case use the FedML distributed version
    # (./fedml_experiments/distributed_fedavg) instead.
    # Use the unpacked class_num rather than the positional dataset[7], so the
    # output dimension cannot silently drift from the unpacking above.
    model = create_model(args, model_name=args.model, output_dim=class_num)

    # The trainer index is the rank shifted down by one. Rank 0 is presumably
    # the server (see the "+ 1" in size below) -- TODO confirm.
    client_index = client_ID - 1
    trainer = FedAVGTrainer(client_index, train_data_local_dict,
                            train_data_local_num_dict, train_data_num, device,
                            model, args)

    # One process slot per client sampled each round, plus one extra slot
    # (presumably the server).
    size = args.client_num_per_round + 1
    client_manager = FedAVGClientManager(args,
                                         trainer,
                                         rank=client_ID,
                                         size=size,
                                         backend="MQTT")
    # NOTE(review): start_training() is only reached if run() returns. This
    # code therefore relies on the MQTT-backed run() not blocking, with
    # message handling continuing in the background -- confirm.
    client_manager.run()
    client_manager.start_training()

    # Crude keep-alive so the process does not exit while the background
    # handling runs. 100000 s is roughly 27.8 hours; this is not a training
    # timeout.
    time.sleep(100000)