def run(rank, world_size, loss_train, acc_train, dataset_train, idxs_users, net_glob, grc):
    """Run distributed federated training rounds with gradient compression.

    Each round: one local training pass on this rank's shard, compress every
    parameter gradient through ``grc``, apply an SGD step, then all-reduce the
    loss/accuracy to rank 0 for logging and checkpointing.

    Args:
        rank: this process's rank in the distributed group (rank 0 logs/saves).
        world_size: number of participating processes.
        loss_train: indexable sink for per-round averaged loss (e.g. shared tensor).
        acc_train: indexable sink for per-round averaged accuracy.
        dataset_train: the full training dataset; sharded via ``dict_users``.
        idxs_users: per-round user-index rows; ``row[rank]`` selects this rank's user.
        net_glob: the shared global model being trained.
        grc: gradient-compression object exposing ``acc``/``step``/``printr``
            and ``uncompressed_size``/``size`` counters.

    Side effects: writes TensorBoard scalars and ``net_state_dict.pt`` on rank 0.
    """
    if rank == 0:
        # Encode compressor, epoch count and sparsity ratio in the run name.
        foldername = f'{args.compressor}epoch{args.epochs}ratio{args.gsr}'
        tb = SummaryWriter("runs/" + foldername)

    # Loop-invariant: same parameters and hyperparameters every round,
    # so build the optimizer once instead of once per round.
    optimizer = torch.optim.SGD(net_glob.parameters(), lr=args.lr,
                                momentum=args.momentum)

    round_idx = 0  # renamed from `round`, which shadowed the builtin
    for user_row in idxs_users:  # one iteration per communication round
        idx = dict_users[user_row[rank]]  # this rank's data indices for the round
        epoch_loss = torch.zeros(1)

        local = LocalUpdate(args=args, dataset=dataset_train, idxs=idx)
        train_loss = local.train(net=net_glob)

        # Compress each gradient in place before the optimizer step.
        for name, parameter in net_glob.named_parameters():
            grad = parameter.grad.data
            grc.acc(grad)
            new_tensor = grc.step(grad, name)
            grad.copy_(new_tensor)

        optimizer.step()
        net_glob.zero_grad()

        epoch_loss += train_loss
        dist.reduce(epoch_loss, 0, dist.ReduceOp.SUM)

        # NOTE(review): the model is left in eval mode here and never switched
        # back to train mode in this loop — presumably LocalUpdate.train()
        # handles that; confirm.
        net_glob.eval()
        train_acc = torch.zeros(1)
        acc, _ = local.inference(net_glob, dataset_train, idx)
        train_acc += acc
        dist.reduce(train_acc, 0, dist.ReduceOp.SUM)

        if rank == 0:
            torch.save(net_glob.state_dict(), 'net_state_dict.pt')
            # Reduced sums -> averages across the world.
            epoch_loss /= world_size
            train_acc /= world_size
            loss_train[round_idx] = epoch_loss[0]
            acc_train[round_idx] = train_acc[0]
            tb.add_scalar("Loss", epoch_loss[0], round_idx)
            tb.add_scalar("Accuracy", train_acc[0], round_idx)
            tb.add_scalar("Uncompressed Size", grc.uncompressed_size, round_idx)
            tb.add_scalar("Compressed Size", grc.size, round_idx)
            if round_idx % 50 == 0:
                print('Round {:3d}, Rank {:1d}, Average loss {:.6f}, Average Accuracy {:.2f}%'.format(
                    round_idx, dist.get_rank(), epoch_loss[0], train_acc[0]))
        round_idx += 1

    if rank == 0:
        tb.close()
        print("Printing Compression Stats...")
        grc.printr()
def run(rank, world_size, loss_train, acc_train, epoch, dataset_train, idx, net_glob):
    """Run one federated round with Deep Gradient Compression (DGC).

    Restores the model and DGC state from checkpoints, performs
    ``args.local_ep`` local epochs with a DGC gradient update after each,
    then reduces loss/accuracy to rank 0, which logs and re-checkpoints.

    NOTE(review): this ``def run`` has the same name as the compression-loop
    ``run`` defined earlier — if both live in one module the later definition
    shadows the earlier; presumably they are alternative entry points. Confirm.

    Args:
        rank: this process's rank (rank 0 prints, saves checkpoints).
        world_size: number of participating processes.
        loss_train: indexable sink for per-epoch averaged loss.
        acc_train: indexable sink for per-epoch averaged accuracy.
        epoch: global round index used for bookkeeping slots.
        dataset_train: full training dataset.
        idx: this rank's data indices.
        net_glob: the shared global model.

    Side effects: reads/writes ``net_state_dict.pt`` and ``dgc_state_dict.pt``.
    """
    # Resume global model and DGC trainer state from the previous round.
    net_glob.load_state_dict(torch.load('net_state_dict.pt'))
    dgc_trainer = DGC(model=net_glob, rank=rank, size=world_size,
                      momentum=args.momentum, full_update_layers=[4],
                      percentage=args.dgc)
    dgc_trainer.load_state_dict(torch.load('dgc_state_dict.pt'))

    epoch_loss = torch.zeros(1)
    for local_ep in range(args.local_ep):  # renamed from `iter` (shadowed builtin)
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=idx)
        b_loss = local.train(net=net_glob, world_size=world_size, rank=rank)
        epoch_loss += b_loss
        if rank == 0:
            print("Local Epoch: {}, Local Epoch Loss: {}".format(local_ep, b_loss))
        # Exchange sparsified gradients and apply the DGC update.
        dgc_trainer.gradient_update()

    # Average over local epochs, then sum across ranks onto rank 0.
    epoch_loss /= args.local_ep
    dist.reduce(epoch_loss, 0, dist.ReduceOp.SUM)

    # NOTE(review): model stays in eval mode after this point — presumably
    # the next round's LocalUpdate.train() restores train mode; confirm.
    net_glob.eval()
    train_acc = torch.zeros(1)
    local = LocalUpdate(args=args, dataset=dataset_train, idxs=idx)
    acc, _ = local.inference(net_glob, dataset_train, idx)
    train_acc += acc
    dist.reduce(train_acc, 0, dist.ReduceOp.SUM)

    if rank == 0:
        torch.save(net_glob.state_dict(), 'net_state_dict.pt')
        torch.save(dgc_trainer.state_dict(), 'dgc_state_dict.pt')
        # Reduced sums -> averages across the world.
        epoch_loss /= world_size
        train_acc /= world_size
        loss_train[epoch] = epoch_loss[0]
        acc_train[epoch] = train_acc[0]
        print('Round {:3d}, Rank {:1d}, Average loss {:.6f}, Average Accuracy {:.2f}%'.format(
            epoch, dist.get_rank(), epoch_loss[0], train_acc[0]))