# Example #1
# 0
    # --- dataset selection ---
    # Map the dataset name to its partition-data loader. Anything unrecognized
    # falls back to the CIFAR-10 loader (NOTE(review): silent fallback — confirm
    # this is intended rather than raising on an unknown dataset).
    if args.dataset == "cifar10":
        data_loader = load_partition_data_cifar10
    elif args.dataset == "cifar100":
        data_loader = load_partition_data_cifar100
    elif args.dataset == "cinic10":
        data_loader = load_partition_data_cinic10
    else:
        data_loader = load_partition_data_cifar10
    train_data_num, test_data_num, train_data_global, test_data_global, \
    data_local_num_dict, train_data_local_dict, test_data_local_dict, \
    class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                            args.partition_alpha, args.client_number, args.batch_size)

    # Bundle the loader outputs so downstream code can index them positionally
    # (class_num ends up at index 7).
    dataset = [
        train_data_num, test_data_num, train_data_global, test_data_global,
        data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ]

    # --- model creation ---
    if args.model == "resnet56":
        model = resnet56(class_num)
    elif args.model == "mobilenet":
        model = mobilenet(class_num=class_num)
    else:
        # Fail fast instead of silently handing model=None to the trainer,
        # which would only surface as a confusing error deep inside training.
        raise ValueError("Unsupported model: " + str(args.model))

    trainer = FedAvgTrainer(dataset, model, device, args)
    trainer.train()
    trainer.global_test()
# Example #2
# 0
    args = parser.parse_args()
    logger.info(args)

    # Prefer the requested GPU when CUDA is available; otherwise run on CPU.
    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    logger.info(device)

    wandb.init(
        project="fedml",
        name="FedAVG-r" + str(args.comm_round) + "-e" + str(args.epochs) + "-lr" + str(args.lr),
        config=args
    )

    # Set the random seed. The np.random seed determines the dataset partition.
    # The torch_manual_seed determines the initial weight.
    # We fix these two, so that we can reproduce the result.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    # load data
    dataset = load_data(args, args.dataset)

    # create model.
    # Note if the model is DNN (e.g., ResNet), the training will be very slow.
    # In this case, please use our FedML distributed version (./fedml_experiments/distributed_fedavg)
    # assumes dataset[7] is the class count produced by load_data — TODO confirm
    model = create_model(args, model_name=args.model, output_dim=dataset[7])
    # Use the module-level `logger` consistently; `logging.info` would route
    # through the root logger and bypass this script's configured handlers.
    logger.info(model)

    trainer = FedAvgTrainer(dataset, model, device, args)
    trainer.train()
# Example #3
# 0
    # torch.manual_seed(seed)
    # Print tensors at high precision so logged values are reproducible to compare.
    torch.set_printoptions(precision=10)

    # data
    logger.info("Partitioning data")
    # Partition the chosen dataset across args.client_number clients.
    # input: dataset name, data/log dirs, partition strategy, client count, alpha
    # output: raw train/test arrays, per-client sample-index map, per-client class counts
    X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data(
        args.dataset,
        args_datadir,
        args_logdir,
        args.partition,
        args.client_number,
        args_alpha,
        args=args)
    # Global (un-partitioned) train/test dataloaders.
    # NOTE(review): the literal 32 is presumably the test batch size — confirm
    # against get_dataloader's signature.
    train_dl_global, test_dl_global = get_dataloader(args.dataset,
                                                     args_datadir,
                                                     args.batch_size, 32)

    # Number of distinct labels observed in the training split.
    n_classes = len(np.unique(y_train))
    print("n_classes = " + str(n_classes))
    print("traindata_cls_counts = " + str(traindata_cls_counts))
    print("train_dl_global number = " + str(len(train_dl_global)))
    print("test_dl_global number = " + str(len(test_dl_global)))

    # Run federated training rounds, then evaluate the final global model.
    trainer = FedAvgTrainer(net_dataidx_map, train_dl_global, test_dl_global,
                            device, args, n_classes, logger, switch_wandb)
    trainer.train()
    trainer.global_test()