def get_local_data(world_size, rank, batch_size):
    """Return the training data loader for this worker.

    Selects the dataset named by the module-level ``DATA_SET`` constant
    ('Mnist', 'Cifar10' or 'KWS').  With the module-level ``IID`` flag set,
    each rank gets an IID shard; otherwise a non-IID split partitioned
    across ``world_size`` workers and indexed by ``rank``.

    Args:
        world_size: total number of workers (used only for non-IID splits).
        rank: this worker's rank; selects its shard/partition.
        batch_size: batch size for the returned loader.

    Returns:
        The training data loader for this worker's partition.

    Raises:
        ValueError: if ``DATA_SET`` is not one of the supported names.
            (Previously an unknown DATA_SET crashed later with
            UnboundLocalError on ``train_loader``.)
    """
    if IID:
        # IID sharding: dataset classes take (rank, batch_size).
        if DATA_SET == 'Mnist':
            return Mnist(rank, batch_size).get_train_data()
        elif DATA_SET == 'Cifar10':
            return Cifar10(rank, batch_size).get_train_data()
        elif DATA_SET == 'KWS':
            return KWS(rank, batch_size).get_train_data()
    else:
        # Non-IID: split is built over the whole world, then indexed by rank.
        if DATA_SET == 'Mnist':
            return Mnist_noniid(batch_size, world_size).get_train_data(rank)
        elif DATA_SET == 'Cifar10':
            return Cifar10_noniid(batch_size, world_size).get_train_data(rank)
        elif DATA_SET == 'KWS':
            return KWS_noniid(batch_size, world_size).get_train_data(rank)
    raise ValueError('Unsupported DATA_SET: {!r}'.format(DATA_SET))
Example #2
0
def get_testset(rank):
    """Return the test data loader for this worker.

    Mirrors ``get_local_data``: the dataset is chosen by the module-level
    ``DATA_SET`` constant, and the module-level ``IID`` flag selects
    between the per-rank IID dataset and the shared non-IID test set.

    Args:
        rank: this worker's rank (used only for the IID datasets).

    Returns:
        The test data loader.

    Raises:
        ValueError: if ``DATA_SET`` is not one of the supported names.
            (Previously an unknown DATA_SET crashed later with
            UnboundLocalError on ``test_loader``.)
    """
    if IID:
        if DATA_SET == 'Mnist':
            return Mnist(rank).get_test_data()
        elif DATA_SET == 'Cifar10':
            return Cifar10(rank).get_test_data()
        elif DATA_SET == 'KWS':
            return KWS(rank).get_test_data()
    else:
        # Non-IID test sets are shared, so no rank argument is needed.
        if DATA_SET == 'Mnist':
            return Mnist_noniid().get_test_data()
        elif DATA_SET == 'Cifar10':
            return Cifar10_noniid().get_test_data()
        elif DATA_SET == 'KWS':
            return KWS_noniid().get_test_data()
    raise ValueError('Unsupported DATA_SET: {!r}'.format(DATA_SET))
Example #3
0
File: multi_slave.py  Project: holoword/fl
def run(size, rank):
    """Federated-learning worker loop (one process).

    Builds a CNN, loads this rank's MNIST partition (the shared IID set,
    or a fixed non-IID shard for ranks 1-5), then for each epoch:
    pulls the current global model via ``get_new_model``, evaluates
    accuracy on a held-out batch, trains one pass over the local data,
    and reduces the parameters to rank 0 for averaging.

    Args:
        size: world size (number of participating processes).
        rank: this process's distributed rank.

    Raises:
        ValueError: in non-IID mode for ranks outside 1-5, for which no
            data shard exists.  (Previously ``train_loader``/``test_data``
            were silently left unbound and the function crashed later
            with UnboundLocalError.)
    """
    modell = model.CNN()
    # modell = model.AlexNet()

    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()

    if IID:
        train_loader = Mnist().get_train_data()
        test_data = Mnist().get_test_data()
    else:
        if not 1 <= rank <= 5:
            raise ValueError(
                'non-IID data is only defined for ranks 1-5, got %d' % rank)
        # Each rank has its own pre-made shard accessor
        # (get_train_data1 .. get_train_data5, likewise for test).
        train_loader = getattr(Mnist_noniid(), 'get_train_data%d' % rank)()
        test_data = getattr(Mnist_noniid(), 'get_test_data%d' % rank)()

    # Keep the LAST batch of the test loader as a fixed evaluation set
    # (same behavior as the original enumerate loop).
    test_x = test_y = None
    for test_x, test_y in test_data:
        pass

    # Process group containing every rank [0, size).
    group = dist.new_group(list(range(size)))

    for epoch in range(MAX_EPOCH):

        # Pull the averaged global model before local training.
        modell = get_new_model(modell, group)

        # Evaluate the freshly-pulled global model on the held-out batch.
        test_output, _last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) \
            / float(test_y.size(0))

        # One local pass over this rank's training shard.
        for b_x, b_y in train_loader:
            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Sum locally-trained parameters onto rank 0 for averaging.
        # NOTE(review): dist.reduce_op is deprecated in modern PyTorch
        # (dist.ReduceOp is the replacement) -- kept as-is for
        # compatibility with the torch version this project pins.
        for param in modell.parameters():
            dist.reduce(param.data, dst=0, op=dist.reduce_op.SUM, group=group)

        # Log to file and console; build the message once to avoid the
        # duplicated print arguments of the original.
        report = ('Epoch: ', epoch, ' Rank: ', rank,
                  '| train loss: %.4f' % loss.data.numpy(),
                  '| test accuracy: %.2f' % accuracy)
        # `with` guarantees the log file is closed even if printing raises
        # (the original open()/close() pair leaked the handle on error).
        with open('./test.txt', 'a') as f:
            print(*report, file=f)
        print(*report)
Example #4
0
def run(size, rank):
    """Federated-learning worker loop (one process), console-logging variant.

    Builds a CNN, loads this rank's MNIST partition (the shared IID set,
    or a fixed non-IID shard for ranks 1-5) together with the full test
    tensors, then for each epoch: pulls the current global model via
    ``get_new_model``, trains one pass over the local data, reduces the
    parameters to rank 0, and prints loss/accuracy.

    Args:
        size: world size (number of participating processes).
        rank: this process's distributed rank.

    Raises:
        ValueError: in non-IID mode for ranks outside 1-5, for which no
            data shard exists.  (Previously ``train_loader``/``test_x``/
            ``test_y`` were silently left unbound and the function crashed
            later with UnboundLocalError.)
    """
    modell = model.CNN()
    # modell = model.AlexNet()

    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()

    def _test_tensors(dataset):
        # Reshape (N, 28, 28) -> (N, 1, 28, 28) and scale pixel values
        # into [0, 1].  NOTE(review): relies on the legacy torchvision
        # `.test_data`/`.test_labels` attributes -- deprecated in newer
        # torchvision (use `.data`/`.targets`); confirm against the
        # pinned version.
        x = torch.unsqueeze(dataset.test_data, dim=1).type(
            torch.FloatTensor) / 255.
        return x, dataset.test_labels

    if IID:
        train_loader = Mnist().get_train_data()
        test_x, test_y = _test_tensors(Mnist().get_test_data())
    else:
        if not 1 <= rank <= 5:
            raise ValueError(
                'non-IID data is only defined for ranks 1-5, got %d' % rank)
        # Each rank has its own pre-made shard accessor
        # (get_train_data1 .. get_train_data5, likewise for test); this
        # replaces five copy-pasted branches that differed only in a digit.
        train_loader = getattr(Mnist_noniid(), 'get_train_data%d' % rank)()
        test_x, test_y = _test_tensors(
            getattr(Mnist_noniid(), 'get_test_data%d' % rank)())

    # Process group containing every rank [0, size).
    group = dist.new_group(list(range(size)))

    for epoch in range(MAX_EPOCH):

        # Pull the averaged global model before local training.
        modell = get_new_model(modell)

        # One local pass over this rank's training shard.
        for b_x, b_y in train_loader:
            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Sum locally-trained parameters onto rank 0 for averaging.
        # NOTE(review): dist.reduce_op is deprecated in modern PyTorch
        # (dist.ReduceOp is the replacement) -- kept as-is for
        # compatibility with the torch version this project pins.
        for param in modell.parameters():
            dist.reduce(param, dst=0, op=dist.reduce_op.SUM, group=group)

        # Evaluate the post-training model on the full test set.
        test_output, _last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float(
            (pred_y == test_y.data.numpy()).astype(int).sum()) / float(
                test_y.size(0))
        print('Epoch: ', epoch, ' Rank: ', rank,
              '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy)