import copy
import pickle
import random
import select

import numpy as np
import torch
from torch import nn
from opacus import PrivacyEngine  # pre-1.0 Opacus API: the engine is attached to the optimizer

# Helpers such as get_optimizer and encrypt_torch_state_dict, as well as the
# data loaders, socket globals, and CLI args, are defined elsewhere in the repo.


def normal_train(args, net, optimizer, loss_func, ldr_train, ldr_valid):
    epoch_loss = []
    best_valid_loss = np.finfo(float).max
    best_valid_net = copy.deepcopy(net)
    net.train()
    for epoch in range(args.local_ep):
        batch_loss = []
        for batch_idx, (attributes, labels) in enumerate(ldr_train):
            attributes = attributes.to(args.device)
            labels = labels.to(device=args.device, dtype=torch.long)
            optimizer.zero_grad()
            log_probs = net(attributes)
            loss = loss_func(log_probs, labels)
            loss.backward()
            optimizer.step()
            # if args.verbose and batch_idx % 10 == 0:
            #     print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #         epoch, batch_idx * len(attributes), len(ldr_train.dataset),
            #         100. * batch_idx / len(ldr_train), loss.item()))
            batch_loss.append(loss.item())
        epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # validate after every local epoch and keep the best-performing snapshot
        net.eval()
        _, tmp_loss_valid = test_bank(net, ldr_valid, args)
        if tmp_loss_valid < best_valid_loss:
            best_valid_loss = tmp_loss_valid  # keep the threshold in sync, so only genuine improvements are kept
            best_valid_net = copy.deepcopy(net)
        net.train()
    # return the snapshot with the lowest validation loss rather than the
    # final-epoch weights, which would leave best_valid_net unused
    return best_valid_net.state_dict(), sum(epoch_loss) / len(epoch_loss)
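
# `test_bank` is referenced throughout but not shown in this excerpt. Below is
# a minimal, hypothetical sketch consistent with the call sites, which expect
# an (accuracy, average loss) tuple; the repo's actual implementation may differ.
def test_bank(net, ldr_test, args):
    loss_func = nn.CrossEntropyLoss(reduction='sum')  # sum per batch, average at the end
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for attributes, labels in ldr_test:
            attributes = attributes.to(args.device)
            labels = labels.to(device=args.device, dtype=torch.long)
            log_probs = net(attributes)
            total_loss += loss_func(log_probs, labels).item()
            correct += (log_probs.argmax(dim=1) == labels).sum().item()
            n += labels.size(0)
    return 100.0 * correct / n, total_loss / n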
for epoch in range(args.epochs):
    batch_loss = []
    for batch_idx, (data, target) in enumerate(train_loader):
        # move the batch to the configured device; CrossEntropyLoss expects long targets
        data, target = data.to(args.device), target.to(args.device, dtype=torch.long)
        optimizer.zero_grad()
        output = net_glob(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    loss_avg = sum(batch_loss) / len(batch_loss)
    print('Round{:3d}, Average loss {:.3f}'.format(epoch, loss_avg))
    loss_train.append(loss_avg)
    # validate once per round and checkpoint the best model seen so far
    net_glob.eval()
    acc_valid, tmp_loss_valid = test_bank(net_glob, valid_loader, args)
    print('Round{:3d}, Validation loss {:.3f}'.format(epoch, tmp_loss_valid))
    loss_valid.append(tmp_loss_valid)
    if tmp_loss_valid < best_valid_loss:
        best_valid_loss = tmp_loss_valid
        best_net_glob = copy.deepcopy(net_glob)
        print('SAVE BEST MODEL AT EPOCH {}'.format(epoch))
    net_glob.train()

if args.dp:
    # detach the privacy engine (attached to the optimizer earlier) before saving
    privacy_engine.detach()
torch.save(best_net_glob, save_prefix + '_best.pt')
torch.save(net_glob, save_prefix + '_final.pt')
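
# Usage sketch (e.g. in a separate evaluation script): torch.save() above was
# given the full module rather than a state_dict, so torch.load() returns a
# ready-to-use model. `save_prefix`, `valid_loader`, and `args` are assumed to
# be the same objects as in the training script.
best_model = torch.load(save_prefix + '_best.pt', map_location=args.device)
best_model.eval()
acc_best, loss_best = test_bank(best_model, valid_loader, args)
print('Best checkpoint: acc {:.2f}, loss {:.3f}'.format(acc_best, loss_best))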
def server(cur_net, current_iter, current_server_rank_id, best_valid_loss,
           best_net_glob, server_flag):
    loss_locals = []
    w_state_dict_locals = []

    # --- local training on the server's own shard ---
    cur_net.train()
    optimizer = get_optimizer(args, cur_net)
    loss_func = nn.CrossEntropyLoss()
    if args.dp:
        privacy_engine = PrivacyEngine(
            cur_net,
            batch_size=args.bs,
            sample_size=len(local_train_loader.dataset),  # number of samples, not batches
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=0.3,
            max_grad_norm=1.2,
            secure_rng=args.secure_rng)
        privacy_engine.attach(optimizer)
    current_state_dict, current_loss = normal_train(
        args, cur_net, optimizer, loss_func, local_train_loader, valid_loader)
    if args.dp:
        privacy_engine.detach()
    loss_locals.append(current_loss)
    if args.tphe:
        w_state_dict_locals.append(encrypt_torch_state_dict(pub_key, current_state_dict))
    else:
        w_state_dict_locals.append(current_state_dict)

    # --- receive the remaining local updates from the other peers ---
    loop = True
    while loop:
        # select() returns the subset of sockets that is ready to be read
        rList, wList, error_sockets = select.select(server_connection_list, [], [])
        for sockfd in rList:
            # assumes each pickled update arrives in a single recv() of at most
            # args.buffer bytes (e.g. 760586945); larger models would need a
            # length-prefixed receive loop
            tmp_pkl_data = sockfd.recv(int(args.buffer))
            tmp_state_dict, tmp_loss = pickle.loads(tmp_pkl_data)
            w_state_dict_locals.append(tmp_state_dict)
            loss_locals.append(tmp_loss)
            if len(w_state_dict_locals) == args.num_users:
                loop = False
                break

    # --- aggregate, redistribute, and load the averaged weights ---
    aggregated_state_dict = state_dict_aggregation(w_state_dict_locals)
    send_aggregated_weight_state_dict_to_all(aggregated_state_dict)
    parse_aggregated_state_dict(aggregated_state_dict, cur_net)

    loss_avg = sum(loss_locals) / len(loss_locals)
    print('Round{:3d}, Average loss {:.3f}'.format(current_iter, loss_avg))
    loss_train.append(loss_avg)
    cur_net.eval()
    acc_valid, tmp_loss_valid = test_bank(cur_net, valid_loader, args)
    print('Round{:3d}, Validation loss {:.3f}'.format(current_iter, tmp_loss_valid))
    loss_valid.append(tmp_loss_valid)
    if tmp_loss_valid < best_valid_loss:
        best_valid_loss = tmp_loss_valid
        best_net_glob = copy.deepcopy(cur_net)
        print('SAVE BEST MODEL AT EPOCH {}'.format(current_iter))

    # --- hand the server role to a randomly chosen peer for the next round ---
    next_server_rank_id = random.randint(0, args.num_users - 1)  # inclusive on both ends
    send_metadata_to_all(loss_avg, tmp_loss_valid, next_server_rank_id)
    if next_server_rank_id != args.rank:
        server_flag = False
    current_server_rank_id = next_server_rank_id
    print("\33[31m\33[1m Current server rank id {} \33[0m".format(current_server_rank_id))
    return cur_net, current_server_rank_id, best_valid_loss, best_net_glob, server_flag
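
# `state_dict_aggregation` is defined elsewhere in the repo; for the plaintext
# path it amounts to a FedAvg-style element-wise average. Below is a minimal,
# hypothetical sketch of that case only. The TPHE branch operates on Paillier
# ciphertexts, which support addition but must be handled by the repo's real
# implementation (decryption happens on the clients after redistribution).
def state_dict_aggregation(w_state_dict_locals):
    avg = copy.deepcopy(w_state_dict_locals[0])
    for key in avg.keys():
        for other in w_state_dict_locals[1:]:
            avg[key] = avg[key] + other[key]
        avg[key] = avg[key] / len(w_state_dict_locals)
    return avg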