Exemple #1
0
def create_model(args, model_name, output_dim):
    """Return a freshly built model for ``model_name`` on ``args.dataset``.

    The dataset-agnostic architectures ("resnet56", "mobilenet") only use
    ``output_dim``. Every other architecture is matched together with
    ``args.dataset``, which is read only for those model names.

    Args:
        args: configuration object exposing ``dataset``.
        model_name: architecture key ("lr", "cnn", "resnet18_gn", "rnn",
            "resnet56" or "mobilenet").
        output_dim: number of output classes, forwarded where accepted.

    Returns:
        The model instance, or None for an unhandled combination.
    """
    summary = "create_model. model_name = %s, output_dim = %s" % (model_name, output_dim)
    logging.info(summary)

    if model_name == "lr":
        dataset = args.dataset
        if dataset == "mnist":
            logging.info("LogisticRegression + MNIST")
            return LogisticRegression(28 * 28, output_dim)
        if dataset == "stackoverflow_lr":
            logging.info("lr + stackoverflow_lr")
            return LogisticRegression(10000, output_dim)
    elif model_name == "cnn":
        if args.dataset == "femnist":
            logging.info("CNN + FederatedEMNIST")
            return CNN_DropOut(False)
    elif model_name == "resnet18_gn":
        if args.dataset == "fed_cifar100":
            logging.info("ResNet18_GN + Federated_CIFAR100")
            return resnet18()
    elif model_name == "rnn":
        dataset = args.dataset
        if dataset == "shakespeare":
            logging.info("RNN + shakespeare")
            return RNN_OriginalFedAvg()
        if dataset == "fed_shakespeare":
            logging.info("RNN + fed_shakespeare")
            return RNN_OriginalFedAvg()
        if dataset == "stackoverflow_nwp":
            logging.info("RNN + stackoverflow_nwp")
            return RNN_StackOverFlow()
    elif model_name == "resnet56":
        return resnet56(class_num=output_dim)
    elif model_name == "mobilenet":
        return mobilenet(class_num=output_dim)

    # No supported (model_name, dataset) pairing was found.
    return None
Exemple #2
0
def create_model(args, model_name, output_dim):
    """Return a freshly built model for ``model_name`` on ``args.dataset``.

    Models are looked up by architecture first. For the dataset-specific
    architectures ("lr", "cnn", "purchasemlp", "texasmlp", "resnet18_gn",
    "rnn", "resnet20") the dataset then selects the concrete configuration.
    "resnet56", "vgg11" and "mobilenet" never consult ``args.dataset``.

    Args:
        args: configuration object exposing ``dataset``.
        model_name: architecture key.
        output_dim: number of output classes; only some branches use it, and
            the others hard-code their class count.

    Returns:
        The model instance, or None when the combination is not handled.
    """
    summary = ("create_model. model_name = %s, output_dim = %s" %
               (model_name, output_dim))
    logging.info(summary)

    if model_name == "lr":
        dataset = args.dataset
        if dataset in ("mnist", "fmnist", "emnist"):
            logging.info("LogisticRegression + MNIST")
            return LogisticRegression(28 * 28, output_dim, flatten=True)
        if dataset == "adult":
            return LogisticRegression(105, 2, flatten=False)
        if dataset == "stackoverflow_lr":
            logging.info("lr + stackoverflow_lr")
            return LogisticRegression(10000, output_dim)
        return None

    if model_name == "cnn":
        dataset = args.dataset
        if dataset in ("mnist", "fmnist"):
            logging.info("CNN + MNIST")
            return CNN_DropOut(True)
        if dataset == "emnist":
            logging.info("CNN + MNIST")
            # NOTE(review): the other CNN_DropOut calls pass a boolean
            # only_digits; 47 is merely truthy unless CNN_DropOut treats an
            # int as a class count -- confirm against its definition.
            return CNN_DropOut(only_digits=47)
        if dataset in ("har", "har_subject"):
            logging.info("CNN + HAR")
            return HAR_CNN(data_size=(9, 128), n_classes=6)
        if dataset == "femnist":
            logging.info("CNN + FederatedEMNIST")
            return CNN_DropOut(False)
        if dataset == "cifar10":
            logging.info("CNN + CIFAR10")
            return CNNCifar()
        return None

    if model_name == "purchasemlp":
        if args.dataset == "purchase100":
            return PurchaseMLP(input_dim=600, n_classes=100)
        return None

    if model_name == "texasmlp":
        if args.dataset == "texas100":
            return TexasMLP(input_dim=6169, n_classes=100)
        return None

    if model_name == "resnet18_gn":
        if args.dataset == "fed_cifar100":
            logging.info("ResNet18_GN + Federated_CIFAR100")
            return resnet18()
        return None

    if model_name == "rnn":
        dataset = args.dataset
        if dataset == "shakespeare":
            logging.info("RNN + shakespeare")
            return RNN_OriginalFedAvg()
        if dataset == "fed_shakespeare":
            logging.info("RNN + fed_shakespeare")
            return RNN_OriginalFedAvg()
        if dataset == "stackoverflow_nwp":
            logging.info("RNN + stackoverflow_nwp")
            return RNN_StackOverFlow()
        return None

    if model_name == "resnet56":
        return resnet56(class_num=output_dim)

    if model_name == "vgg11":
        return VGG("VGG11")

    if model_name == "resnet20":
        dataset = args.dataset
        if dataset == "cifar10":
            return resnet20_cifar(num_classes=10)
        if dataset == "chmnist":
            return resnet20_cifar(num_classes=8)
        return None

    if model_name == "mobilenet":
        return mobilenet(class_num=output_dim)

    return None
Exemple #3
0
def create_model(args, model_name, output_dim):
    """Construct the model for ``model_name`` (and, where relevant, ``args.dataset``).

    Args:
        args: run configuration. ``args.dataset`` is read only by the
            dataset-specific branches ("lr", "cnn", "resnet18_gn", "rnn").
        model_name: architecture key.
        output_dim: number of output classes for the models that accept it.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if the (model_name, dataset) combination is not
            supported.
    """
    logging.info("create_model. model_name = %s, output_dim = %s",
                 model_name, output_dim)
    if model_name == "lr" and args.dataset == "mnist":
        logging.info("LogisticRegression + MNIST")
        model = LogisticRegression(28 * 28, output_dim)
    elif model_name == "cnn" and args.dataset == "femnist":
        logging.info("CNN + FederatedEMNIST")
        model = CNN_DropOut(False)
    elif model_name == "resnet18_gn" and args.dataset == "fed_cifar100":
        logging.info("ResNet18_GN + Federated_CIFAR100")
        model = resnet18()
    elif model_name == "rnn" and args.dataset == "shakespeare":
        logging.info("RNN + shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "rnn" and args.dataset == "fed_shakespeare":
        logging.info("RNN + fed_shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "lr" and args.dataset == "stackoverflow_lr":
        logging.info("lr + stackoverflow_lr")
        model = LogisticRegression(10004, output_dim)
    elif model_name == "rnn" and args.dataset == "stackoverflow_nwp":
        logging.info("RNN + stackoverflow_nwp")
        model = RNN_StackOverFlow()
    elif model_name == "resnet56":
        model = resnet56(class_num=output_dim)
    elif model_name == "mobilenet":
        model = mobilenet(class_num=output_dim)
    elif model_name == 'mobilenet_v3':
        # TODO: model_mode is fixed to LARGE.
        # model_mode in {LARGE: 5.15M, SMALL: 2.94M}
        model = MobileNetV3(model_mode='LARGE', num_classes=output_dim)
    elif model_name == 'efficientnet':
        # Reference coefficients (width, depth, resolution, dropout) per variant:
        #   b0 (1.0, 1.0, 224, 0.2)   b1 (1.0, 1.1, 240, 0.2)
        #   b2 (1.1, 1.2, 260, 0.3)   b3 (1.2, 1.4, 300, 0.3)
        #   b4 (1.4, 1.8, 380, 0.4)   b5 (1.6, 2.2, 456, 0.4)
        #   b6 (1.8, 2.6, 528, 0.5)   b7 (2.0, 3.1, 600, 0.5)
        #   b8 (2.2, 3.6, 672, 0.5)   l2 (4.3, 5.3, 800, 0.5)
        # Only 'efficientnet-b0' is built here.
        model = EfficientNet.from_name(model_name='efficientnet-b0',
                                       num_classes=output_dim)
    else:
        # getattr: args.dataset may legitimately be absent for the
        # dataset-agnostic model names; keep raising NotImplementedError.
        raise NotImplementedError(
            "create_model: unsupported model_name=%r for dataset=%r"
            % (model_name, getattr(args, "dataset", None)))

    return model
Exemple #4
0
def create_model(args, model_name, output_dim):
    """Construct the model for ``model_name`` / ``args.dataset``.

    Side effect: for the "lr"+"mnist" and "rnn"+"shakespeare" combinations,
    ``args.client_optimizer`` is set to "sgd".

    Args:
        args: mutable run configuration exposing ``dataset``.
        model_name: architecture key ("lr", "rnn", "resnet56", "mobilenet").
        output_dim: number of output classes for the models that accept it.

    Returns:
        The constructed model, or None when the combination is not handled.
    """
    logging.info("create_model. model_name = %s, output_dim = %s",
                 model_name, output_dim)
    model = None
    if model_name == "lr" and args.dataset == "mnist":
        model = LogisticRegression(28 * 28, output_dim)
        args.client_optimizer = "sgd"
    elif model_name == "rnn" and args.dataset == "shakespeare":
        # Build with the model's own defaults, as every other create_model
        # variant does. 28 * 28 is the MNIST input size from the LR branch
        # and does not apply to this RNN.
        model = RNN_OriginalFedAvg()
        args.client_optimizer = "sgd"
    elif model_name == "resnet56":
        model = resnet56(class_num=output_dim)
    elif model_name == "mobilenet":
        model = mobilenet(class_num=output_dim)
    return model
Exemple #5
0
def create_model(args, model_name, output_dim):
    """Construct the model for ``model_name`` (and, where relevant, ``args.dataset``).

    Args:
        args: run configuration. ``args.dataset`` is read only by the
            dataset-specific branches ("lr", "rnn", "cnn", "resnet18_gn").
        model_name: architecture key.
        output_dim: number of output classes, passed to every model that
            accepts a class count.

    Returns:
        The constructed model, or None when the combination is not handled.
    """
    logging.info("create_model. model_name = %s, output_dim = %s",
                 model_name, output_dim)
    model = None
    if model_name == "lr" and args.dataset == "mnist":
        logging.info("LogisticRegression + MNIST")
        model = LogisticRegression(28 * 28, output_dim)
    elif model_name == "rnn" and args.dataset == "shakespeare":
        logging.info("RNN + shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "cnn" and args.dataset == "femnist":
        logging.info("CNN + FederatedEMNIST")
        model = CNN_DropOut(False)
    elif model_name == "resnet18_gn" and args.dataset == "fed_cifar100":
        logging.info("ResNet18_GN + Federated_CIFAR100")
        model = resnet18()
    elif model_name == "rnn" and args.dataset == "fed_shakespeare":
        logging.info("RNN + fed_shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "lr" and args.dataset == "stackoverflow_lr":
        logging.info("lr + stackoverflow_lr")
        model = LogisticRegression(10004, output_dim)
    elif model_name == "rnn" and args.dataset == "stackoverflow_nwp":
        logging.info("RNN + stackoverflow_nwp")
        model = RNN_StackOverFlow()
    elif model_name == "resnet56":
        model = resnet56(class_num=output_dim)
    elif model_name == "mobilenet":
        model = mobilenet(class_num=output_dim)
    elif model_name == "mobilenet_v2":
        # Assumes `models` is torchvision.models, whose builder defaults to
        # 1000 classes; size the classifier to output_dim like other branches.
        model = models.mobilenet_v2(num_classes=output_dim)
    elif model_name == 'mobilenet_v3':
        # TODO: model_mode is fixed to LARGE.
        # model_mode in {LARGE: 5.15M, SMALL: 2.94M}
        model = MobileNetV3(model_mode='LARGE', num_classes=output_dim)
    elif model_name == 'efficientnet':
        # EfficientNet() has no usable no-argument form; build the b0 variant
        # by name, as the other create_model variants in this file do.
        model = EfficientNet.from_name(model_name='efficientnet-b0',
                                       num_classes=output_dim)

    return model
Exemple #6
0
def create_model(args, model_name, output_dim):
    """Construct the model for ``model_name`` (and, where relevant, ``args.dataset``).

    Args:
        args: run configuration. ``args.dataset`` is read only by the
            dataset-specific branches.
        model_name: architecture key.
        output_dim: number of output classes, forwarded to the models that
            accept it (several branches hard-code their class count instead).

    Returns:
        The constructed model, or None when the (model_name, dataset)
        combination is not handled.
    """
    logging.info("create_model. model_name = %s, output_dim = %s",
                 model_name, output_dim)
    model = None
    if model_name == "lr" and args.dataset in ["mnist", "fmnist", "emnist"]:
        logging.info("LogisticRegression + MNIST")
        model = LogisticRegression(28 * 28, output_dim, flatten=True)
    elif model_name == "cnn" and args.dataset in ["mnist", "fmnist", "emnist"]:
        if args.dataset in ["mnist", "fmnist"]:
            logging.info("CNN + MNIST")
            model = CNN_DropOut(True)
        elif args.dataset == "emnist":
            logging.info("CNN + MNIST")
            # NOTE(review): the other CNN_DropOut calls pass a boolean
            # only_digits; 47 is merely truthy unless CNN_DropOut treats an
            # int as a class count -- confirm against its definition.
            model = CNN_DropOut(only_digits=47)
    elif model_name == "cnn" and args.dataset == "har":
        logging.info("CNN + HAR")
        model = HAR_CNN(data_size=(9, 128), n_classes=6)
    elif model_name == "cnn" and args.dataset == "femnist":
        logging.info("CNN + FederatedEMNIST")
        model = CNN_DropOut(False)
    elif model_name == "purchasemlp":
        if args.dataset == "purchase100":
            model = PurchaseMLP(input_dim=600, n_classes=100)
    elif model_name == "texasmlp":
        if args.dataset == "texas100":
            model = TexasMLP(input_dim=6169, n_classes=100)
    elif model_name == 'lr' and args.dataset == "adult":
        model = LogisticRegression(105, 2, flatten=False)
    elif model_name == "resnet18_gn" and args.dataset == "fed_cifar100":
        logging.info("ResNet18_GN + Federated_CIFAR100")
        model = resnet18()
    elif model_name == "rnn" and args.dataset == "shakespeare":
        logging.info("RNN + shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "rnn" and args.dataset == "fed_shakespeare":
        logging.info("RNN + fed_shakespeare")
        model = RNN_OriginalFedAvg()
    elif model_name == "lr" and args.dataset == "stackoverflow_lr":
        logging.info("lr + stackoverflow_lr")
        model = LogisticRegression(10004, output_dim)
    elif model_name == "rnn" and args.dataset == "stackoverflow_nwp":
        logging.info("RNN + stackoverflow_nwp")
        model = RNN_StackOverFlow()
    elif model_name == "resnet56":
        model = resnet56(class_num=output_dim)
    elif model_name == "vgg11":
        model = VGG("VGG11")
    elif model_name == "resnet20":
        if args.dataset == "cifar10":
            model = resnet20_cifar(num_classes=10)
        elif args.dataset == "chmnist":
            model = resnet20_cifar(num_classes=8)
    elif model_name == "mobilenet":
        model = mobilenet(class_num=output_dim)
    elif model_name == 'mobilenet_v3':
        # model_mode in {LARGE: 5.15M, SMALL: 2.94M}; LARGE is used here.
        model = MobileNetV3(model_mode='LARGE', num_classes=output_dim)
    elif model_name == 'efficientnet':
        # Reference coefficients (width, depth, resolution, dropout) per variant:
        #   b0 (1.0, 1.0, 224, 0.2)   b1 (1.0, 1.1, 240, 0.2)
        #   b2 (1.1, 1.2, 260, 0.3)   b3 (1.2, 1.4, 300, 0.3)
        #   b4 (1.4, 1.8, 380, 0.4)   b5 (1.6, 2.2, 456, 0.4)
        #   b6 (1.8, 2.6, 528, 0.5)   b7 (2.0, 3.1, 600, 0.5)
        #   b8 (2.2, 3.6, 672, 0.5)   l2 (4.3, 5.3, 800, 0.5)
        # Only 'efficientnet-b0' is built here.
        model = EfficientNet.from_name(model_name='efficientnet-b0',
                                       num_classes=output_dim)

    return model
Exemple #7
0
        data_loader = load_partition_data_distributed_cinic10
    else:
        data_loader = load_partition_data_distributed_cifar10

    train_data_num, train_data_global, \
    test_data_global, local_data_num, \
    train_data_local, test_data_local, class_num = data_loader(process_id, args.dataset, args.data_dir,
                                                               args.partition_method, args.partition_alpha,
                                                               args.client_number, args.batch_size)

    # create the model
    model = None
    split_layer = 1
    if args.model == "mobilenet":
        model = mobilenet(class_num=class_num)
    elif args.model == "resnet56":
        model = resnet56(class_num=class_num)

    fc_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Flatten(), nn.Linear(fc_features, class_num))
    # Split The model
    client_model = nn.Sequential(
        *nn.ModuleList(model.children())[:split_layer])
    server_model = nn.Sequential(
        *nn.ModuleList(model.children())[split_layer:])

    SplitNN_distributed(process_id, worker_number, device, comm, client_model,
                        server_model, train_data_num, train_data_global,
                        test_data_global, local_data_num, train_data_local,
                        test_data_local, args)