import os
import threading
import time

import numpy as np
import pandas as pd
import torch

# NOTE: load_data, AvgMaster, MCMaster, Server, W2wb, save_results and
# MNIST_SHAPE are assumed to be provided by the repo's own modules.


def run_master(config):
    w_i, num_worker = config.site_id, config.num_worker
    beta, S, tau, steps = config.beta, config.S, config.tau, config.steps
    device = config.device

    X = load_data(w_i, num_worker)
    x_dim = tuple(X.shape[1:])

    master = AvgMaster(X, num_worker, x_dim, beta, S, tau, device)

    # Start the socket server that exchanges messages with the workers.
    addrport = ('localhost', config.port)
    server = Server(addrport, master.send_queue, master.recv_queue, num_worker)
    server.accept()

    # Receive worker messages on a background thread.
    recv_thread = threading.Thread(target=server.recv_loop, daemon=True)
    recv_thread.start()

    # Logging output directory for this (S, tau) configuration.
    log_dir = f'./logs/mnist_avg/S={S} tau={tau}'
    os.makedirs(log_dir, exist_ok=True)

    # Wall-clock times and iterates recorded at every global step.
    time_vals = []
    z_vals = torch.zeros((steps,) + x_dim, device=device)
    print('Started computing!')
    time_0 = time.time()
    # Main driver loop, paced by the server's send iterator.
    for _ in server.send_iter():
        if master.stop:
            print('Optimization ended, computing objective values.')
            # Evaluate the objective at every recorded iterate.
            data = []
            for step in range(steps):
                obj = master.objective(z_vals[step])
                data.append([step + 1, time_vals[step], obj.item()])

            dataframe = pd.DataFrame(
                data, columns=['global_step', 'wall_time', 'objective'])
            dataframe.to_pickle(log_dir + '/logs.pkl')
            dataframe.to_csv(log_dir + '/logs.csv')

            np.save(log_dir + '/results.npy', master.z.data.cpu().numpy())
            print('Done!')

            break

        master.receive()
        if master.update():
            # Reset the clock at the first accepted update so wall_time is
            # measured from the first global step.
            if master.k == 1:
                time_0 = time.time()
            z_vals[master.k - 1] = master.z
            time_vals.append(time.time() - time_0)

        # Request termination once the target number of global steps is reached.
        if master.k == steps:
            master.stop_algorithm()
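

# --- Illustration only (not part of the original repo) -----------------------
# A minimal sketch of the `config` object run_master expects, based on the
# attributes read above (site_id, num_worker, beta, S, tau, steps, device,
# port). The default values here are assumptions for illustration; a launcher
# would call e.g. run_master(_example_master_config()).
def _example_master_config():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--site_id', type=int, default=0)
    parser.add_argument('--num_worker', type=int, default=4)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--S', type=int, default=1)
    parser.add_argument('--tau', type=int, default=1)
    parser.add_argument('--steps', type=int, default=100)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--port', type=int, default=23333)
    # parse_args([]) returns a Namespace filled with the defaults above.
    return parser.parse_args([])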


# Variant of run_master for the MNIST logistic-regression experiment
# (MCMaster, with train/test losses); presumably defined in a separate driver
# script from the averaging version above.
def run_master(config):
    w_i, num_worker = config.site_id, config.num_worker
    beta, S, tau, steps = config.beta, config.S, config.tau, config.steps
    device = config.device

    X, Y, X_test, Y_test, xm, xd = load_data(w_i, num_worker)
    # Parameter matrix: MNIST_SHAPE input features plus a bias row, mapped to
    # the 10 digit classes (W2wb splits it back into weights and bias below).
    x_dim = (MNIST_SHAPE + 1, 10)

    master = MCMaster(X, Y, X_test, Y_test, num_worker, x_dim, beta, S, tau,
                      device)

    # Start the socket server that exchanges messages with the workers.
    addrport = ('localhost', config.port)
    server = Server(addrport, master.send_queue, master.recv_queue, num_worker)
    server.accept()

    # Receive worker messages on a background thread.
    recv_thread = threading.Thread(target=server.recv_loop, daemon=True)
    recv_thread.start()

    # Logging output directory for this (S, tau) configuration.
    log_dir = f'./logs/mnist_logistic/S={S} tau={tau}'
    os.makedirs(log_dir, exist_ok=True)

    # Wall-clock times and iterates recorded at every global step.
    time_vals = []
    z_vals = torch.zeros((steps,) + x_dim, device=device)
    print('Started computing!')
    time_0 = time.time()
    # Main driver loop, paced by the server's send iterator.
    for _ in server.send_iter():
        if master.stop:
            print('Optimization ended, computing train/test losses.')
            # Evaluate the train and test losses at every recorded iterate.
            data = []
            for step in range(steps):
                train_loss, test_loss = master.objective(z_vals[step])
                data.append([
                    step + 1, time_vals[step],
                    train_loss.item(),
                    test_loss.item()
                ])

            dataframe = pd.DataFrame(data,
                                     columns=[
                                         'global_step', 'wall_time',
                                         'train_loss', 'test_loss'
                                     ])
            dataframe.to_pickle(log_dir + '/logs.pkl')
            dataframe.to_csv(log_dir + '/logs.csv')

            # Split the stacked parameters into weights and bias and save them
            # together with the data statistics (xm, xd) from load_data.
            w, b = W2wb(master.z)
            save_results(log_dir + '/results.npz', w=w, b=b, xm=xm, xd=xd)
            print('Done!')
            break

        master.receive()
        if master.update():
            # Reset the clock at the first accepted update so wall_time is
            # measured from the first global step.
            if master.k == 1:
                time_0 = time.time()
            z_vals[master.k - 1] = master.z
            time_vals.append(time.time() - time_0)

        # Request termination once the target number of global steps is reached.
        if master.k == steps:
            master.stop_algorithm()
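

# --- Illustration only (not part of the original repo) -----------------------
# W2wb and save_results are project helpers whose implementations are not
# shown here. The sketch below captures one plausible reading, given that
# master.z has shape (MNIST_SHAPE + 1, 10): the last row of the stacked
# parameter matrix holds the bias, and results.npz is written with numpy.savez.
def _example_W2wb(W):
    # Split a stacked (features + 1, classes) matrix into weights and bias.
    return W[:-1], W[-1]


def _example_save_results(path, **arrays):
    import numpy as np
    # Move any torch tensors to CPU numpy arrays before saving.
    np.savez(path, **{
        k: v.detach().cpu().numpy() if hasattr(v, 'detach') else np.asarray(v)
        for k, v in arrays.items()
    })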