Ejemplo n.º 1
0
def create_model(args, model_name, output_dim):
    """Instantiate the model selected by ``model_name`` (and, for some models, ``args.dataset``).

    Supported combinations:
      * ``"lr"`` with ``args.dataset == "mnist"``  -> ``LogisticRegression``;
        also sets ``args.client_optimizer = 'sgd'``.
      * ``"rnn"`` with ``args.dataset == "shakespeare"`` -> ``RNN_OriginalFedAvg``;
        also sets ``args.client_optimizer = 'adam'``.
      * ``"resnet56"`` -> ``resnet56(class_num=output_dim)`` (any dataset).
      * ``"mobilenet"`` -> ``mobilenet(class_num=output_dim)`` (any dataset).

    Args:
        args: Configuration object. ``args.dataset`` is read only for the
            "lr" and "rnn" names. ``args.client_optimizer`` is overwritten for
            those two combinations.
        model_name: Name of the architecture to build.
        output_dim: Number of outputs / classes passed to the constructor.

    Returns:
        The constructed model, or ``None`` if the combination is not recognised.
        In that case ``args`` is left unmodified.
    """
    # The model is built before ``args`` is mutated, so a failing
    # constructor leaves the configuration untouched.
    if model_name == "lr" and args.dataset == "mnist":
        built = LogisticRegression(28 * 28, output_dim)
        args.client_optimizer = 'sgd'
        return built

    if model_name == "rnn" and args.dataset == "shakespeare":
        # NOTE(review): 28 * 28 matches the MNIST image size used above.
        # Shakespeare is a text dataset, so confirm that RNN_OriginalFedAvg's
        # constructor really expects these two positional arguments.
        built = RNN_OriginalFedAvg(28 * 28, output_dim)
        args.client_optimizer = 'adam'
        return built

    # The convolutional models do not depend on the dataset and do not
    # change the client optimizer.
    if model_name == "resnet56":
        return resnet56(class_num=output_dim)
    if model_name == "mobilenet":
        return mobilenet(class_num=output_dim)

    return None
Ejemplo n.º 2
0
    elif args.dataset == "cinic10":
        data_loader = load_partition_data_distributed_cinic10
    else:
        data_loader = load_partition_data_distributed_cifar10

    train_data_num, train_data_global,\
    test_data_global, local_data_num, \
    train_data_local, test_data_local, class_num = data_loader(process_id, args.dataset, args.data_dir,
                                                               args.partition_method, args.partition_alpha,
                                                               args.client_number, args.batch_size)

    # create the model
    model = None
    split_layer = 1
    if args.model == "mobilenet":
        model = mobilenet(class_num=class_num)
    elif args.model == "resnet56":
        model = resnet56(class_num=class_num)

    fc_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Flatten(), nn.Linear(fc_features, class_num))
    #Split The model
    client_model = nn.Sequential(
        *nn.ModuleList(model.children())[:split_layer])
    server_model = nn.Sequential(
        *nn.ModuleList(model.children())[split_layer:])

    SplitNN_distributed(process_id, worker_number, device, comm, client_model,
                        server_model, train_data_num, train_data_global,
                        test_data_global, local_data_num, train_data_local,
                        test_data_local, args)