Code Example #1
File: rnd_mnist.py Project: Jueun-Park/wfedavg-mnist
from module.rnd import RandomNetworkDistillation
from module.load_and_split_mnist_dataset import concat_data

if __name__ == "__main__":
    rnd = RandomNetworkDistillation(log_interval=1000,
                                    lr=1e-3,
                                    path="model/allenv/rnd_model/",
                                    verbose=1)
    # use all ten MNIST digit classes (the "allenv" setting in the save path)
    train_tensor, valid_tensor = concat_data(list(range(0, 10)), mode="tensor")
    rnd.set_data(train_tensor, valid_tensor)
    rnd.learn(30)
    rnd.save()
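The RandomNetworkDistillation class comes from module/rnd.py, which is not shown in this listing. It presumably follows the standard RND recipe (Burda et al., 2018): a trainable predictor network is regressed onto a frozen, randomly initialized target network, and the prediction error serves as a novelty score for the input data. A minimal, generic PyTorch sketch of that core idea, independent of this module's actual API (layer sizes here are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

# frozen random target and trainable predictor, both mapping MNIST-shaped
# images to 64-dimensional embeddings
target = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64))
predictor = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64))
for p in target.parameters():
    p.requires_grad_(False)

opt = torch.optim.Adam(predictor.parameters(), lr=1e-3)
x = torch.rand(32, 1, 28, 28)  # dummy batch standing in for MNIST data
loss = F.mse_loss(predictor(x), target(x))  # prediction error = novelty signal
opt.zero_grad()
loss.backward()
opt.step()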
Code Example #2
from module.rnd import RandomNetworkDistillation
from module.load_and_split_mnist_dataset import concat_data

NUM_WORKERS = 4
epochs = 100

for i in range(0, 8, 2):
    td, vd = concat_data(list(range(10))[i:i + 4], mode="tensor")
    rnd = RandomNetworkDistillation(log_interval=1000,
                                    lr=1e-3,
                                    use_cuda=False,
                                    verbose=1,
                                    log_tensorboard=True,
                                    path=f"model/subenv_{i}-{i+4}/rnd_model/")
    rnd.set_data(td, vd)
    rnd.learn(epochs)
    rnd.save()
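The slice list(range(10))[i:i + 4] with i in {0, 2, 4, 6} carves the ten MNIST digit classes into four overlapping sub-environments, one RND model per split:

>>> [list(range(10))[i:i + 4] for i in range(0, 8, 2)]
[[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]

Note that the save paths are labeled {i}-{i+4} even though each split actually ends at digit i+3; the plotting code in Code Example #6 uses the {i}-{i+3} form instead.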
Code Example #3
import numpy as np


def model_align(w, base_parameter_dict, sub_model_parameters, alpha):
    # signature and outer loops reconstructed from the call site in Code Example #5
    for key in base_parameter_dict:
        layer_parameter = []
        for i in range(len(sub_model_parameters)):
            layer_parameter.append(sub_model_parameters[i][key].numpy())
        # weighted average
        delta = np.average(layer_parameter, axis=0, weights=w)
        base_parameter_dict[key] = (
            1 - alpha) * base_parameter_dict[key] + alpha * delta


alpha = 0.5
base_idx = 1

if __name__ == "__main__":
    num_model_4_indices = [list(range(10))[i:i + 4] for i in range(0, 8, 2)]
    num_model_4_comments = [str(i) + "-" + str(i + 4) for i in range(0, 8, 2)]

    # load base subenv dataset
    base_train_ts, base_valid_ts = concat_data(num_model_4_indices[base_idx],
                                               mode="tensor")

    # get weight from base gan model
    gan_weights = []
    for data_com in num_model_4_comments:
        gan = GenerativeAdversarialNetwork(
            save_path=f"./model/subenv_{data_com}/gan")
        gan.load()
        w = gan.get_discriminator_output(base_valid_ts)
        gan_weights.append(w.item())
    print(gan_weights)
    gan_weights = softmax(gan_weights)
    print(gan_weights)

    # load base model parameter
    base_model = Net()
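The softmax that turns the raw discriminator outputs into averaging weights is imported outside this excerpt, so its exact source is not visible here. A minimal numpy sketch of what it presumably computes, with the usual max-subtraction for numerical stability:

import numpy as np

def softmax(x):
    # shift by the max before exponentiating so large inputs cannot overflow
    e = np.exp(np.asarray(x, dtype=np.float64) - np.max(x))
    return e / e.sum()

The resulting weights are positive and sum to 1, so the sub-models whose GAN discriminators respond most strongly to the base validation set dominate the weighted average in model_align.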
Code Example #4
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--base-index", type=int)
args = parser.parse_args()
base_idx = args.base_index

use_cuda = True

if __name__ == "__main__":
    print(f"Base index: {base_idx}")
    # load base model parameter
    base_model = Net()
    base_model.load_state_dict(
        torch.load(f"model/subenv_{model_comments[base_idx]}/mnist_cnn.pt"))
    # every client starts the round from the chosen base model's parameters
    sub_model_parameters = [base_model.state_dict() for _ in range(num_models)]

    # client training
    for i, data_idx in enumerate(model_indices):
        print(f"data index: {data_idx}")
        td, vd = concat_data(data_idx, mode="dataset")
        tdl = DataLoader(td, batch_size=64, shuffle=True)
        vdl = DataLoader(vd, batch_size=64, shuffle=True)
        learner = Learner(tdl,
                          vdl,
                          lr=0.001,
                          log_interval=100,
                          use_cuda=use_cuda)
        learner.model.load_state_dict(sub_model_parameters[i])
        learner.learn(fed_learn_epochs)
        sub_model_parameters[i] = learner.model.state_dict()
        learner.save(f"./wfed_model_base{base_idx}/subenv_{model_comments[i]}")
        del learner
Code Example #5
    base_model.load_state_dict(
        torch.load(f"model/subenv_{model_comments[base_idx]}/mnist_cnn.pt"))

    sub_model_parameters = []
    for i in range(num_models):
        net = Net()
        net.load_state_dict(
            torch.load(
                f"./wfed_model_base{base_idx}/subenv_{model_comments[i]}/mnist_cnn.pt"
            ))
        sub_model_parameters.append(net.state_dict())
        del net

    # load base subenv dataset
    base_train_ds, base_valid_ds = concat_data(model_indices[base_idx],
                                               mode="dataset")

    test_losses = []
    accuracies = []
    labels = []
    aligned_model = Net()
    for i, w in enumerate(weights):
        # client model align
        learner = Learner(DataLoader(base_train_ds, batch_size=64),
                          DataLoader(base_valid_ds, batch_size=64),
                          use_cuda=use_cuda)
        base_parameter_dict = base_model.state_dict()
        model_align(w, base_parameter_dict, sub_model_parameters, alpha=alpha)
        aligned_model.load_state_dict(base_parameter_dict)

        # evaluate fedavg model
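Putting Code Examples #3 and #5 together, model_align applies, per parameter tensor, base = (1 - alpha) * base + alpha * sum_i w[i] * sub[i]. A small self-contained check of that update rule with stand-in arrays (all values here are made up for illustration):

import numpy as np

w = np.array([0.1, 0.2, 0.3, 0.4])                    # normalized weights
subs = [np.full((2, 2), float(i)) for i in range(4)]  # stand-in client layers
base = np.zeros((2, 2))
alpha = 0.5

delta = np.average(subs, axis=0, weights=w)  # weighted average across clients
base = (1 - alpha) * base + alpha * delta
print(base)  # every entry is 0.5 * (0.1*0 + 0.2*1 + 0.3*2 + 0.4*3) = 1.0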
Code Example #6
    # plt.suptitle("Discriminator output in dynamic dataset")
    # plt.show()

    # ===

    subplot_id = 1
    for model_id in range(0, 8, 2):
        gan = GenerativeAdversarialNetwork(
            save_path=f"model/subenv_{model_id}-{model_id+4}/gan")
        gan.load()

        labels = []
        discriminator_output = []
        for data_id in range(0, 8, 2):

            td, vd = concat_data(list(range(10))[data_id:data_id + 4],
                                 mode="tensor")
            output = gan.get_discriminator_output(vd)
            print(output)
            labels.append(f"{data_id}-{data_id+3}")
            discriminator_output.append(output)
        del td, vd
        plt.subplot(2, 2, subplot_id)
        plt.title(f"model label {model_id}-{model_id+3}")
        plt.ylim((-0.1, 1.1))
        plt.plot(labels, discriminator_output, "rs--")
        for a, b in zip(labels, discriminator_output):
            plt.text(a, b, f"{b.item():.2f}")
        plt.xlabel("data label")
        plt.ylabel("discriminator output")
        subplot_id += 1
    del gan
Code Example #7
from torch.utils.data import DataLoader
from module.learner import Learner
from module.load_and_split_mnist_dataset import concat_data
from info import model_comments, model_indices, base_learn_epochs

NUM_WORKERS = 4

# train and save one base model per sub-environment data split
for i, index in enumerate(model_indices):
    td, vd = concat_data(index)
    tdl = DataLoader(td, batch_size=64, shuffle=True, num_workers=NUM_WORKERS)
    vdl = DataLoader(vd, batch_size=64, shuffle=True, num_workers=NUM_WORKERS)
    learner = Learner(tdl,
                      vdl,
                      lr=0.005,
                      log_interval=100,
                      tensorboard=True,
                      tensorboard_comment=model_comments[i])
    learner.learn(base_learn_epochs)
    learner.save(f"model/subenv_{model_comments[i]}")
    print(f"save subenv_{model_comments[i]}")
    del learner