예제 #1
0
    # Client-side start-up: obtain this device's data, build the model and
    # local trainer, then hand control to the FedAvg client manager.
    # presumably downloads and unpacks the dataset found at args.dataset_url
    # (helper is defined outside this fragment) -- confirm
    download_and_unzip(args.dataset_url)

    # Load device-specific data.
    # The device id is client_ID minus one, passed on as a string.
    device_id = str(client_ID - 1)
    dataset = load_data_by_device(args, args.dataset, device_id)
    # The loader's result is unpacked into exactly eight names, in this order;
    # any other length raises ValueError here.
    [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ] = dataset

    # Create the model.
    # Note: if the model is a DNN (e.g., ResNet), training will be very slow.
    # In that case, please use the FedML distributed version
    # (./fedml_experiments/distributed_fedavg).
    # dataset[7] is the same element that was unpacked as class_num above.
    model = create_model(args, model_name=args.model, output_dim=dataset[7])
    model_trainer = MyModelTrainer(model)
    # NOTE(review): client_index is not assigned in the visible lines; it must
    # be bound earlier in the enclosing scope -- confirm.
    model_trainer.set_id(client_index)

    # Build the local FedAvg trainer from the per-client data dictionaries.
    # NOTE(review): device is likewise bound outside this fragment.
    trainer = FedAVGTrainer(client_index, train_data_local_dict,
                            train_data_local_num_dict, test_data_local_dict,
                            train_data_num, device, args, model_trainer)

    # presumably the extra 1 accounts for the server process -- confirm
    size = args.client_num_per_round + 1
    # This client uses client_ID as its rank and the "GRPC" backend.
    client_manager = FedAVGClientManager(args,
                                         trainer,
                                         rank=client_ID,
                                         size=size,
                                         backend="GRPC")
    # NOTE(review): start_training() executes only after run() returns;
    # confirm run() does not block indefinitely for this backend.
    client_manager.run()
    client_manager.start_training()
예제 #2
0
파일: app.py 프로젝트: qinyeli/FedML-Server
    # Server-side start-up: pick a device, load the data, build the model and
    # aggregator, then run the FedAvg server manager.
    # Use the first CUDA GPU ("cuda:0") when CUDA is available, else the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load data.
    dataset = load_data(args, args.dataset)
    # The loader's result is unpacked into exactly eight names, in this order;
    # any other length raises ValueError here.
    [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ] = dataset

    # Create the model.
    # Note: if the model is a DNN (e.g., ResNet), training will be very slow.
    # In that case, please use the FedML distributed version
    # (./fedml_experiments/distributed_fedavg).
    # dataset[7] is the same element that was unpacked as class_num above.
    model = create_model(args, model_name=args.model, output_dim=dataset[7])
    model_trainer = MyModelTrainer(model)

    # The aggregator receives the global and per-client data along with the
    # number of clients per round, the device, and the model trainer.
    aggregator = FedAVGAggregator(train_data_global, test_data_global,
                                  train_data_num, train_data_local_dict,
                                  test_data_local_dict,
                                  train_data_local_num_dict,
                                  args.client_num_per_round, device, args,
                                  model_trainer)
    # presumably one slot per participating client plus one for this server
    # -- confirm
    size = args.client_num_per_round + 1
    # The server takes rank 0 and communicates over the "MQTT" backend.
    server_manager = FedAVGServerManager(args,
                                         aggregator,
                                         rank=0,
                                         size=size,
                                         backend="MQTT",
                                         is_preprocessed=args.is_preprocessed)
    server_manager.run()