def multi_rounds_fedavg(avged_weight, train_data, test_data, global_test_data, global_test_label, n_clients, factors, lr, comm_round):
    criterion = nn.CrossEntropyLoss()  # define the loss up front; it is used in every local training/eval pass below
    for cr in range(comm_round):
        logger.info("Start commround of FedAvg {} ....".format(cr))
        nets_list = []
        for client_index in range(n_clients):
            logger.info("Start local training process for client {}, communication round: {} ....".format(client_index, cr))

            client_user_name = TRIAL_USER_NAME[client_index]
            num_samples_train = len(train_data["user_data"][client_user_name]['x'])
            num_samples_test = len(test_data["user_data"][client_user_name]['x'])

            #num_samples_list.append(num_samples_train)
            user_train_data = train_data["user_data"][client_user_name]
            user_test_data = test_data["user_data"][client_user_name]

            model = language_model.RNNModel('LSTM', 80, 8, 256, 1, 0.2, tie_weights=False).to(device)

            #### we need to load the prev avg weight to the model
            new_state_dict = {}
            for param_idx, (key_name, param) in enumerate(model.state_dict().items()):
                temp_dict = {key_name: torch.from_numpy(avged_weight[param_idx])}
                new_state_dict.update(temp_dict)
            model.load_state_dict(new_state_dict)
            ####
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)
            total_loss = 0.0
            for epoch in range(5):
                model.train()
                epoch_start_time = time.time()

                hidden_train = model.init_hidden(BATCH_SIZE)
                for i in range(int(num_samples_train / BATCH_SIZE)):
                    input_data, target_data = process_x(user_train_data['x'][BATCH_SIZE*i:BATCH_SIZE*(i+1)]), process_y(user_train_data['y'][BATCH_SIZE*i:BATCH_SIZE*(i+1)])
      
                    data, targets = torch.from_numpy(input_data).to(device), torch.from_numpy(target_data).to(device)
                    optimizer.zero_grad()

                    hidden_train = repackage_hidden(hidden_train)
                    output, hidden_train = model(data, hidden_train)

                    loss = criterion(output.t(), torch.max(targets, 1)[1])
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                    cur_loss = total_loss
                    total_loss = 0

            eval_batch_size = 10
            model.eval()
            total_val_loss = 0.
            ntokens = 80
            hidden_test = model.init_hidden(eval_batch_size)
            correct_prediction = 0
            with torch.no_grad():
                for i in range(int(num_samples_test / eval_batch_size)):
                    input_data, target_data = process_x(user_test_data['x'][eval_batch_size*i:eval_batch_size*(i+1)]), process_y(user_test_data['y'][eval_batch_size*i:eval_batch_size*(i+1)])
                    data, targets = torch.from_numpy(input_data).to(device), torch.from_numpy(target_data).to(device)

                    hidden_test = repackage_hidden(hidden_test)
                    output, hidden_test = model(data, hidden_test)
                    loss = criterion(output.t(), torch.max(targets, 1)[1])
                    _, pred_label = torch.max(output.t(), 1)
                    correct_prediction += (pred_label == torch.max(targets, 1)[1]).sum().item()

                    total_val_loss += loss.item()
            logger.info('-' * 89)
            logger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | pred: {}/{} | acc: {:.4f}%'.format(epoch, (time.time() - epoch_start_time),
                                                   total_val_loss, correct_prediction, num_samples_test, correct_prediction/num_samples_test*100.0))
            logger.info('-' * 89)

            nets_list.append(model)

        collected_weights = collect_weights(nets_list)
        avged_weight = fed_avg(collected_weights, factors)
        # we now test the performance of the avged model
        new_state_dict = {}
        avged_model = language_model.RNNModel('LSTM', 80, 8, 256, 1, 0.2, tie_weights=False) # we now start from 1 layer
        for param_idx, (key_name, param) in enumerate(avged_model.state_dict().items()):
            temp_dict = {key_name: torch.from_numpy(avged_weight[param_idx])}
            new_state_dict.update(temp_dict)
        avged_model.load_state_dict(new_state_dict)
        # we will need to construct a new global test set based on all test data on each of the clients

        global_eval_batch_size = 10
        global_num_samples_test = len(global_test_data)  # size of the merged global test set passed into this function
        ######################################################
        # here we measure the performance of averaged model
        ######################################################
        logger.info("We measure the performamce of FedAvg here ....")
        avged_model.to(device)
        avged_model.eval()
        total_val_loss = 0.
        ntokens = 80
        hidden_test = avged_model.init_hidden(global_eval_batch_size)
        global_correct_prediction = 0
        with torch.no_grad():
            for i in range(int(global_num_samples_test / global_eval_batch_size)):
                input_data, target_data = process_x(global_test_data[global_eval_batch_size*i:global_eval_batch_size*(i+1)]), process_y(global_test_label[global_eval_batch_size*i:global_eval_batch_size*(i+1)])
                data, targets = torch.from_numpy(input_data).to(device), torch.from_numpy(target_data).to(device)
                hidden_test = repackage_hidden(hidden_test)
                output, hidden_test = avged_model(data, hidden_test)
                loss = criterion(output.t(), torch.max(targets, 1)[1])
                _, pred_label = torch.max(output.t(), 1)
                global_correct_prediction += (pred_label == torch.max(targets, 1)[1]).sum().item()
                total_val_loss += loss.item()
        logger.info('*' * 89)
        logger.info('|FedAvg-ACC Comm Round: {} | On Global Testset | valid loss {:5.2f} | pred: {}/{} | acc: {:.4f}%'.format(cr, total_val_loss, global_correct_prediction, global_num_samples_test, global_correct_prediction/global_num_samples_test*100.0))
        logger.info('*' * 89)
    num_samples_list = [len(train_data["user_data"][TRIAL_USER_NAME[client_index]]['x']) for client_index in range(n_clients)]
    nets_list = []
    criterion = nn.CrossEntropyLoss()

    for client_index in range(n_clients):
        if retrain_flag:
            logger.info("Start local training process for client {} ....".format(client_index))
            client_user_name = TRIAL_USER_NAME[client_index]
            num_samples_train = len(train_data["user_data"][client_user_name]['x'])
            num_samples_test = len(test_data["user_data"][client_user_name]['x'])

            user_train_data = train_data["user_data"][client_user_name]
            user_test_data = test_data["user_data"][client_user_name]

            model = language_model.RNNModel('LSTM', 80, 8, 256, 1, 0.2, tie_weights=False).to(device)
            #optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001, amsgrad=True)
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)

            for epoch in range(TRIAL_EPOCH):
                model.train()
                epoch_start_time = time.time()

                hidden_train = model.init_hidden(BATCH_SIZE)
                for i in range(int(num_samples_train / BATCH_SIZE)):
                    input_data, target_data = process_x(user_train_data['x'][BATCH_SIZE*i:BATCH_SIZE*(i+1)]), process_y(user_train_data['y'][BATCH_SIZE*i:BATCH_SIZE*(i+1)])
      
                    data, targets = torch.from_numpy(input_data).to(device), torch.from_numpy(target_data).to(device)
                    optimizer.zero_grad()

                    hidden_train = repackage_hidden(hidden_train)
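The FedAvg example above (the listing is truncated at this point) calls collect_weights and fed_avg without showing them. Below is a minimal sketch of what the two helpers could look like, assuming parameters are exchanged as numpy arrays in state_dict order and that factors holds normalized per-client weights; the names match the calls above, but the bodies are assumptions rather than the original implementation.

def collect_weights(nets_list):
    # Dump each client model's parameters as numpy arrays, in state_dict order,
    # matching the order used when the averaged weights are reloaded above.
    return [[param.detach().cpu().numpy() for param in net.state_dict().values()]
            for net in nets_list]

def fed_avg(collected_weights, factors):
    # Weighted average of every parameter tensor across clients; `factors` are
    # assumed to be normalized per-client weights (e.g. the fraction of training samples).
    num_params = len(collected_weights[0])
    return [sum(factor * client_weights[param_idx]
                for factor, client_weights in zip(factors, collected_weights))
            for param_idx in range(num_params)]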
Example #3
                    for layer_index, layer_weight in enumerate(
                            collected_weights[client_index]):
                        if layer_index == i:
                            #temp_retrained_weights.append(matched_weights[0].T)
                            temp_retrained_weights.append(avg_encoding_weights)
                        elif layer_index == (i + 1):
                            temp_retrained_weights.append(
                                patched_next_layer_weights)
                        else:
                            temp_retrained_weights.append(layer_weight)

                    retrain_model = language_model.RNNModel(
                        'LSTM',
                        80,
                        next_layer_shape,
                        256,
                        1,
                        0.2,
                        tie_weights=False)  # we now start from 1 layer
                elif i == 1:
                    # we now have the match results based on the w_{i.} matrices; we use them to permute three dimensions:
                    # i) the columns of the w_{h.} matrices, ii) the rows of the w_{h.} matrices, iii) the input shape of the next layer
                    # we process the next layer first:
                    next_layer_weights = collected_weights[client_index][i + 4]
                    patched_next_layer_weights = patch_weights(
                        next_layer_weights, next_layer_shape,
                        assignments[client_index])

                    temp_retrained_weights = [
                        collected_weights[client_index][0], avg_i_weights,
                        avg_h_weights, avg_i_bias, avg_h_bias,
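The matched-retraining fragment above (Example #3 is cut off on both ends) relies on a patch_weights helper to re-align the next layer's input dimension with the matched hidden units. A minimal sketch of such a helper is shown below, assuming the per-client assignment maps local unit indices to positions in the matched global layer; the shape convention and body are assumptions, not the original code.

import numpy as np

def patch_weights(w_j, L_next, assignment_j_c):
    # Hypothetical sketch: scatter the columns of a local weight matrix into a
    # zero-initialized matrix whose input dimension equals the matched global
    # layer size (L_next), following this client's unit assignment.
    if assignment_j_c is None:
        return w_j
    new_w_j = np.zeros((w_j.shape[0], L_next), dtype=w_j.dtype)
    new_w_j[:, assignment_j_c] = w_j
    return new_w_j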
Example #4
def local_retraining_process(n_clients, train_data, test_data,
                             res_local_weight_list, lr, user_names):
    nets_list = []
    # construct a global test set
    global_test_data = []
    global_test_label = []
    global_num_samples_test = 0

    for client_index in range(n_clients):
        client_user_name = user_names[client_index]
        global_num_samples_test += len(
            test_data["user_data"][client_user_name]['x'])
        global_test_data += test_data["user_data"][client_user_name]['x']
        global_test_label += test_data["user_data"][client_user_name]['y']
    global_eval_batch_size = 10

    for client_index in range(n_clients):
        logger.info("Start local training process for client {} ....".format(
            client_index))
        client_user_name = user_names[client_index]
        num_samples_train = len(train_data["user_data"][client_user_name]['x'])
        num_samples_test = len(test_data["user_data"][client_user_name]['x'])

        user_train_data = train_data["user_data"][client_user_name]
        user_test_data = test_data["user_data"][client_user_name]
        model = language_model.RNNModel('LSTM',
                                        80,
                                        8,
                                        256,
                                        1,
                                        0.2,
                                        tie_weights=False)
        criterion = nn.CrossEntropyLoss()
        ##### load the weights to the model
        ####################################
        new_state_dict = {}
        for param_idx, (key_name,
                        param) in enumerate(model.state_dict().items()):
            #logger.info("param idx: {}, para shape: {}, client_index: {}, len res: {}".format(param_idx, param.shape, client_index, len(res_local_weight_list)))
            temp_dict = {
                key_name:
                torch.from_numpy(
                    res_local_weight_list[client_index][param_idx])
            }
            new_state_dict.update(temp_dict)
        model.load_state_dict(new_state_dict)
        model.to(device)

        ##############################################################
        # evaluate the reloaded model on this client's local test set
        ##############################################################
        model.eval()

        eval_batch_size = 10
        total_val_loss = 0.
        ntokens = 80
        hidden_test = model.init_hidden(eval_batch_size)
        correct_prediction = 0
        with torch.no_grad():
            for i in range(int(num_samples_test / eval_batch_size)):
                input_data, target_data = process_x(
                    user_test_data['x'][eval_batch_size * i:eval_batch_size *
                                        (i + 1)]), process_y(
                                            user_test_data['y']
                                            [eval_batch_size *
                                             i:eval_batch_size * (i + 1)])
                data, targets = torch.from_numpy(input_data).to(
                    device), torch.from_numpy(target_data).to(device)

                hidden_test = repackage_hidden(hidden_test)
                output, hidden_test = model(data, hidden_test)
                loss = criterion(output.t(), torch.max(targets, 1)[1])
                _, pred_label = torch.max(output.t(), 1)
                correct_prediction += (pred_label == torch.max(
                    targets, 1)[1]).sum().item()

                total_val_loss += loss.item()
        logger.info('-' * 89)
        logger.info(
            '| Local Global Testset | client index: {} | valid loss {:5.2f} | pred: {}/{} | acc: {:.4f}%'
            .format(client_index, total_val_loss, correct_prediction,
                    num_samples_test,
                    correct_prediction / num_samples_test * 100.0))
        logger.info('-' * 89)

        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=0.0001)
        total_loss = 0.0
        for epoch in range(TRIAL_EPOCH):
            model.train()
            epoch_start_time = time.time()

            hidden_train = model.init_hidden(BATCH_SIZE)
            for i in range(int(num_samples_train / BATCH_SIZE)):
                input_data, target_data = process_x(
                    user_train_data['x']
                    [BATCH_SIZE * i:BATCH_SIZE * (i + 1)]), process_y(
                        user_train_data['y'][BATCH_SIZE * i:BATCH_SIZE *
                                             (i + 1)])

                data, targets = torch.from_numpy(input_data).to(
                    device), torch.from_numpy(target_data).to(device)
                optimizer.zero_grad()

                hidden_train = repackage_hidden(hidden_train)
                output, hidden_train = model(data, hidden_train)

                loss = criterion(output.t(), torch.max(targets, 1)[1])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                cur_loss = total_loss

                total_loss = 0
                start_time = time.time()

            eval_batch_size = 10
            model.eval()
            total_val_loss = 0.
            ntokens = 80
            hidden_test = model.init_hidden(eval_batch_size)
            correct_prediction = 0
            with torch.no_grad():
                for i in range(int(num_samples_test / eval_batch_size)):
                    input_data, target_data = process_x(
                        user_test_data['x']
                        [eval_batch_size * i:eval_batch_size *
                         (i + 1)]), process_y(
                             user_test_data['y'][eval_batch_size *
                                                 i:eval_batch_size * (i + 1)])
                    data, targets = torch.from_numpy(input_data).to(
                        device), torch.from_numpy(target_data).to(device)

                    hidden_test = repackage_hidden(hidden_test)
                    output, hidden_test = model(data, hidden_test)
                    loss = criterion(output.t(), torch.max(targets, 1)[1])
                    _, pred_label = torch.max(output.t(), 1)
                    correct_prediction += (pred_label == torch.max(
                        targets, 1)[1]).sum().item()

                    total_val_loss += loss.item()
            logger.info('-' * 89)
            logger.info(
                '| end of epoch {:3d} | valid loss {:5.2f} | pred: {}/{} | acc: {:.4f}%'
                .format(epoch, total_val_loss, correct_prediction,
                        num_samples_test,
                        correct_prediction / num_samples_test * 100.0))
            logger.info('-' * 89)
        nets_list.append(model)
    return nets_list
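All of the examples above call repackage_hidden between batches, but none of them defines it. A minimal sketch following the standard PyTorch word-language-model recipe is shown below; this is an assumption about the helper, not the original definition.

import torch

def repackage_hidden(h):
    # Detach hidden states from their computation history so that back-propagation
    # through time is truncated at each batch instead of spanning the whole epoch.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)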
Example #5
    for client_index in range(n_clients):
        client_user_name = TRIAL_USER_NAME[client_index]

        global_num_samples_test += len(test_data["user_data"][client_user_name]['x'])
        global_test_data += test_data["user_data"][client_user_name]['x']
        global_test_label += test_data["user_data"][client_user_name]['y']

        global_num_samples_train += len(train_data["user_data"][client_user_name]['x'])
        global_train_data += train_data["user_data"][client_user_name]['x']
        global_train_label += train_data["user_data"][client_user_name]['y']

    global_eval_batch_size = 10

    logger.info("Start training over the entire dataset ....")

    model = language_model.RNNModel('LSTM', 80, 8, 256, 1, 0.2, tie_weights=False).to(device)
    #optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001, amsgrad=True)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0001)

    for epoch in range(TRIAL_EPOCH):
        model.train()
        epoch_start_time = time.time()

        hidden_train = model.init_hidden(BATCH_SIZE)
        for i in range(int(global_num_samples_train / BATCH_SIZE)):
            input_data, target_data = process_x(global_train_data[BATCH_SIZE*i:BATCH_SIZE*(i+1)]), process_y(global_train_label[BATCH_SIZE*i:BATCH_SIZE*(i+1)])

            data, targets = torch.from_numpy(input_data).to(device), torch.from_numpy(target_data).to(device)
            optimizer.zero_grad()

            hidden_train = repackage_hidden(hidden_train)
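For completeness, here is a hedged sketch of how multi_rounds_fedavg from the first example might be driven. It assumes the TRIAL_USER_NAME list and the train_data/test_data layout used throughout the examples, and that global_test_data/global_test_label have been concatenated across clients as in Example #5; the initial weights, learning rate, and round count are illustrative placeholders, not values from the original code.

# Illustrative driver; names and hyper-parameter values are assumptions.
n_clients = len(TRIAL_USER_NAME)
sample_counts = [len(train_data["user_data"][name]['x']) for name in TRIAL_USER_NAME]
factors = [count / float(sum(sample_counts)) for count in sample_counts]  # weight clients by data size

init_model = language_model.RNNModel('LSTM', 80, 8, 256, 1, 0.2, tie_weights=False)
init_weights = [param.detach().cpu().numpy() for param in init_model.state_dict().values()]

multi_rounds_fedavg(init_weights, train_data, test_data,
                    global_test_data, global_test_label,
                    n_clients, factors, lr=0.8, comm_round=10)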