# Select the dataset-specific partition loader; unknown dataset names fall
# back to the CIFAR-10 loader (preserves the original default behavior).
if args.dataset == "cifar10":
    data_loader = load_partition_data_cifar10
elif args.dataset == "cifar100":
    data_loader = load_partition_data_cifar100
elif args.dataset == "cinic10":
    data_loader = load_partition_data_cinic10
else:
    data_loader = load_partition_data_cifar10

# Load and partition the dataset across clients; the loader returns the
# global loaders plus the per-client loaders and sample counts.
train_data_num, test_data_num, train_data_global, test_data_global, \
    data_local_num_dict, train_data_local_dict, test_data_local_dict, \
    class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                            args.partition_alpha, args.client_number,
                            args.batch_size)

# Bundle everything the trainer needs; FedAvgTrainer indexes this list.
dataset = [
    train_data_num, test_data_num, train_data_global, test_data_global,
    data_local_num_dict, train_data_local_dict, test_data_local_dict,
    class_num
]

# Create the model.
if args.model == "resnet56":
    model = resnet56(class_num)
elif args.model == "mobilenet":
    model = mobilenet(class_num=class_num)
else:
    # BUGFIX: the original left `model = None` for unrecognized model names,
    # which would only surface later inside FedAvgTrainer as a confusing
    # AttributeError. Fail fast with a clear message instead.
    raise ValueError("Unsupported model: " + str(args.model))

trainer = FedAvgTrainer(dataset, model, device, args)
trainer.train()
trainer.global_test()
args = parser.parse_args()
logger.info(args)

# Use the requested GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:" + str(args.gpu)
                      if torch.cuda.is_available() else "cpu")
logger.info(device)

# Name the wandb run after the key hyperparameters so runs are comparable.
wandb.init(
    project="fedml",
    name="FedAVG-r" + str(args.comm_round) + "-e" + str(args.epochs)
         + "-lr" + str(args.lr),
    config=args
)

# Set the random seed. The np.random seed determines the dataset partition.
# The torch_manual_seed determines the initial weight.
# We fix these two, so that we can reproduce the result.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# load data
dataset = load_data(args, args.dataset)

# create model.
# Note if the model is DNN (e.g., ResNet), the training will be very slow.
# In this case, please use our FedML distributed version
# (./fedml_experiments/distributed_fedavg)
model = create_model(args, model_name=args.model, output_dim=dataset[7])
# CONSISTENCY FIX: the rest of this script logs through `logger`; the
# original called the root `logging.info` here, bypassing the configured
# logger's handlers/level.
logger.info(model)

trainer = FedAvgTrainer(dataset, model, device, args)
trainer.train()
# torch.manual_seed(seed)
torch.set_printoptions(precision=10)

# data
logger.info("Partitioning data")
# input:  dataset name, data/log dirs, partition strategy and alpha
# output: raw train/test arrays, per-client index map, per-client class counts
X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = \
    partition_data(args.dataset, args_datadir, args_logdir, args.partition,
                   args.client_number, args_alpha, args=args)

# Global (non-partitioned) train/test dataloaders; 32 is the test batch size.
train_dl_global, test_dl_global = get_dataloader(args.dataset, args_datadir,
                                                 args.batch_size, 32)

n_classes = len(np.unique(y_train))
# CONSISTENCY FIX: the block opens with `logger.info(...)` but the original
# then switched to bare print() for the same diagnostics; route everything
# through the configured logger.
logger.info("n_classes = " + str(n_classes))
logger.info("traindata_cls_counts = " + str(traindata_cls_counts))
logger.info("train_dl_global number = " + str(len(train_dl_global)))
logger.info("test_dl_global number = " + str(len(test_dl_global)))

trainer = FedAvgTrainer(net_dataidx_map, train_dl_global, test_dl_global,
                        device, args, n_classes, logger, switch_wandb)
trainer.train()
trainer.global_test()