def test_kmean():
    train_set, valid_set, test_set = DatasetManager.read_dataset(dataset_name="dataset_simulation_20.csv", shared=False)
    kmean = KMean()

    _clusters, _centers = kmean.run(
        dataset=train_set[0],
        n_clusters=5,
        max_iters=100,
        threshold=1.0
    )

    assert _clusters
    assert _centers is not None
def run_experiment():
    from datasets import DatasetManager
    from preprocessing.scaling import get_gaussian_normalization
    from preprocessing.dimensionality_reduction import get_pca

    train_set, valid_set, test_set = DatasetManager.read_dataset()
    dataset, result = train_set

    # Reduce to two dimensions for plotting the data
    dataset = get_gaussian_normalization(dataset)
    # dataset = get_LLE(dataset, num_components=2, n_neighbors=80)
    dataset, explained_variance_ratio_ = get_pca(dataset, num_components=2)

    plot_embedding(dataset, result)
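
# --- Illustrative sketch (not part of the original snippet) ---
# get_gaussian_normalization() and get_pca() above come from the project's own
# preprocessing package and are not shown here. A minimal, hedged sketch of what
# a PCA reduction to `num_components` dimensions could look like with
# scikit-learn; the real helper may differ:
from sklearn.decomposition import PCA

def get_pca_sketch(dataset, num_components=2):
    """Project `dataset` onto its first principal components and also return
    the explained variance ratio, mirroring the tuple unpacked above."""
    pca = PCA(n_components=num_components)
    reduced = pca.fit_transform(dataset)
    return reduced, pca.explained_variance_ratio_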
Example #3
def run(rank, size):
    global fl_round
    global rat_per_class
    # Minimizes MSE
    adversarial_loss = torch.nn.MSELoss()
    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    same_data = False  #set to True if all devices are required to hold the same data
    if same_data:
        os.makedirs("../data/mnist", exist_ok=True)
        train_set = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data/mnist",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(opt.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ]),
            ),
            batch_size=opt.batch_size,
            shuffle=True,
        )
    else:
        manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                                 size - 1, size, rank, opt.iid, 1)
        train_set, _ = manager.get_train_set(opt.magic_num)

    init_groups(size)
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    #For FID calculations
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]
            test_labels = t[1]

    # ----------
    #  Training
    # ----------
    #DIST
    elapsed_time = time()
    num_batches = 0  #This variable acts as a global state variable to sync. between workers and the server
    done_round = True
    group = None
    #The following hack (4 lines) makes training actually run for the number of epochs the user is aiming for: because of the skewness of the data, the actual number of epochs that would run could be smaller than the user estimates. These few lines compensate for that.
    est_len = 50000 // (
        size * opt.batch_size
    )  #Given a dataset of 50,000 images, the estimated number of iterations per epoch is 50000 / (size * batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))
    imgs = []
    for i, (tmps, _) in enumerate(train_set):
        imgs = tmps
        break
    for epoch in range(opt.n_epochs):
        broadcast_model(generator, elapsed_time=elapsed_time)
        fl_round += 1
        num_batches += 1
        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0),
                         requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0),
                        requires_grad=False)

        #HINT: training the generator is not required on the server, yet PyTorch requires it. It does not affect the runtime anyway

        # -----------------
        #  Train Generator
        # -----------------
        z = Variable(
            Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        temp = generator(z)
        if rank == 0:  #MD-GAN trains the generator only on the server
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))
            # Generate a batch of images
            X_g = generator(z)
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))
            # Generate a batch of images
            X_d = generator(z)
            for n in range(size - 1):
                # Sample noise as generator input
                # Generate a batch of images
                dist.broadcast(tensor=X_g, src=0, group=all_groups[n])
                # Generate a batch of images
                dist.broadcast(tensor=X_d, src=0, group=all_groups[n])

        else:  #First, workers receive generated batches from the server
            X_g = torch.zeros(temp.size())
            X_d = torch.zeros(temp.size())
            dist.broadcast(tensor=X_g, src=0, group=all_groups[rank - 1])
            dist.broadcast(tensor=X_d, src=0, group=all_groups[rank - 1])
            if cuda:
                X_g = X_g.cuda()
                X_d = X_d.cuda()

            # Loss measures generator's ability to fool the discriminator
        if rank == 0:
            d_gen = discriminator(temp)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

        if rank != 0:
            L = 12  #MD-GAN parameter: a worker should only do L local iterations per round
            for iter, (imgs_t, _) in enumerate(train_set):
                real_imgs = Variable(imgs_t.type(Tensor))
                if real_imgs.size(0) != opt.batch_size:  #To avoid mismatch problems
                    continue
                optimizer_D.zero_grad()

                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss = adversarial_loss(discriminator(X_d.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)
                d_loss.backward()
                optimizer_D.step()
                if iter == L - 1:
                    break

            optimizer_G.zero_grad()
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))
            X_g = generator(z)
            g_loss = adversarial_loss(discriminator(X_g), valid)
            g_loss.backward()
            optimizer_G.step()
        average_models(generator, elapsed_time=elapsed_time)
        del X_g
        del X_d

        #Print stats and generate images only if this is the server
        batches_done = fl_round
        if rank == 0 and fl_round % 20 == 0:
            print("Rank %d [Epoch %d/%d] [Batch %d/%d] time %f" %
                  (rank, epoch, opt.n_epochs, i, len(train_set),
                   time() - elapsed_time),
                  end=' ' if epoch != 0 else '\n')

            fid_z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (opt.fid_batch, opt.latent_dim))))
            gen_imgs = generator(fid_z)
            mu_gen, sigma_gen = calculate_activation_statistics(
                gen_imgs, fic_model)
            mu_test, sigma_test = calculate_activation_statistics(
                test_imgs[:opt.fid_batch], fic_model)
            fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test,
                                             sigma_test)
            print("FL-round {} FID Score: {}".format(fl_round, fid))
            sys.stdout.flush()
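
# --- Illustrative sketch (not part of the original snippet) ---
# broadcast_model() and average_models() are used above but defined elsewhere in
# the project. A minimal, hedged sketch of what they could look like with
# torch.distributed (names, signatures, and the handling of elapsed_time are
# assumptions; the real helpers may differ):
import torch.distributed as dist

def broadcast_model_sketch(model, src=0, group=None):
    """Send every parameter of `model` from rank `src` to all ranks in `group`."""
    for param in model.parameters():
        dist.broadcast(tensor=param.data, src=src, group=group)

def average_models_sketch(model, group=None):
    """All-reduce the parameters of `model` and divide by the number of ranks."""
    world = dist.get_world_size(group)
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group)
        param.data /= world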
Example #4
def run(rank, size):
    global fl_round
    global rat_per_class
    # !!! Minimizes MSE instead of BCE
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    #DIST (fix the path of data)
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size, size - 1,
                             size, rank, opt.iid)
    train_set, _ = manager.get_train_set(opt.max_samples)

    lbl_count = [0 for _ in range(10)]
    for i, (imgs, lbls) in enumerate(train_set):
        for lbl in lbls:
            lbl_count[lbl.item()] += 1

    #This piece of info should be gathered at the server (to make an informed decision about sampling)
    workers_classes = gather_lbl_count(lbl_count)
    if rank == 0:
        print(workers_classes)
    num_per_class = [
        5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949
    ]  #The per-class sample counts (MNIST) are hard-coded here
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]
    #Calculating entropy at this worker

    #Now, initializing all groups for the whole training process
    #    gp_t = time()
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))
    #    if opt.bench:
    #        print("Time to init the groups: ", time() - gp_t)
    #Calculating entropy of each worker (on the server side) based on these frequencies....
    if rank == 0:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]
        print("Entropies are: ", entropies)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    #For FID calculations
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda()
            test_labels = t[1]
        grouped_test_imgages = [[] for i in range(10)]
        for i, img in enumerate(test_imgs):
            grouped_test_imgages[test_labels[i]].append(img)
        for i, arr in enumerate(grouped_test_imgages):
            grouped_test_imgages[i] = torch.stack(arr)

    # ----------
    #  Training
    # ----------
    #DIST
    elapsed_time = time()
    num_batches = 0  #This variable acts as a global state variable to sync. between workers and the server
    done_round = True
    group = None
    #The following hack (4 lines) makes training actually run for the number of epochs the user is aiming for: because of the skewness of the data, the actual number of epochs that would run could be smaller than the user estimates. These few lines compensate for that.
    est_len = 50000 // (
        size * opt.batch_size
    )  #Given a dataset of 50,000 images, the estimated number of iterations per epoch is 50000 / (size * batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(train_set):
            #DIST
            if done_round:  #This means that a new round should start....done by sampling a few workers and giving them the latest version of the model(s)
                #First step: Choose a group of nodes to do computations in this round....
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
#                broad_t = time()
                if rank in g:
                    broadcast_model(generator, group, elapsed_time)
                    broadcast_model(discriminator, group, elapsed_time)
                    done_round = False
                else:  #This node is not chosen in the current group....no work for this node in this round....just continue and wait for a new announcement from the server
                    done_round = True
                    num_batches = num_batches + opt.local_steps  #Advance the pointer for workers that will not work this round
                    continue
            num_batches += 1
            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0),
                            requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            #            gen_t = time()
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            #            gd_t = time()
            d_gen = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_gen, valid)

            g_loss.backward()
            #            if opt.bench and rank == 0:
            #                print("Time of bakward pass 1 for discriminator ", time() - gd_t)

            #DIST
            #            g_avg_t = time()
            #Averaging step.......added because of distributed setup now!
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    #This is a weighting scheme using the entropies based on the frequency of samples of each class at each worker
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    if rank == 0:
                        weights = [entropies[int(wrk)] for wrk in cur_gp]
                    else:  #dummy else
                        weights = [1.0 / len(cur_gp) for _ in cur_gp]
                    average_models(
                        generator,
                        group,
                        choose_r0,
                        weights,
                        elapsed_time=elapsed_time
                    )  #Experiments show that doing this is bad anyway!
                else:
                    average_models(generator,
                                   group,
                                   choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True
            if rank == 0 and not choose_r0:
                g_p = generator.parameters()
                for param in generator.parameters():
                    param.grad.data = torch.zeros(param.size()).cuda()

            optimizer_G.step()
            # ---------------------
            #  Train Discriminator
            # ---------------------

            #            disc_t = time()
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = 0.5 * (real_loss + fake_loss)

            d_loss.backward()

            #DIST
            #Averaging step.......added because of distributed setup now!
            #            d_avg_t = time()
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    average_models(discriminator,
                                   group,
                                   choose_r0,
                                   weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(discriminator,
                                   group,
                                   choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True
            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(param.size()).cuda()
            optimizer_D.step()

            #Print stats and generate images only if this is the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(), time() - elapsed_time),
                    end=' ' if epoch != 0 else '\n')
                #                sys.stdout.flush()

                # Evaluation step => output images and calculate FID
                if batches_done % opt.sample_interval == 0 and batches_done != 0:
                    #                    pathname = os.path.abspath(os.path.dirname(sys.argv[0]))
                    #                    save_image(gen_imgs.data[:25], pathname+"/images-dist-s{}-w{}/{}-{}.png".format(opt.sample, opt.weight_avg, rank,batches_done), nrow=5, normalize=True)
                    #                    print("=====Calculating FID for round {}======".format(fl_round))
                    fid_z = Variable(
                        Tensor(
                            np.random.normal(0, 1,
                                             (opt.fid_batch, opt.latent_dim))))
                    del gen_imgs
                    gen_imgs = generator(fid_z)
                    mu_gen, sigma_gen = calculate_activation_statistics(
                        gen_imgs, fic_model)
                    mu_test, sigma_test = calculate_activation_statistics(
                        test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()
                    if False:  #not opt.iid:
                        cur = 0
                        fids = [0 for i in range(10)]
                        for i, gp in enumerate(grouped_test_imgages):
                            mu_gen, sigma_gen = calculate_activation_statistics(
                                gen_imgs[cur:cur + len(gp)], fic_model)
                            cur += len(gp)
                            mu_test, sigma_test = calculate_activation_statistics(
                                gp, fic_model)
                            fids[i] = calculate_frechet_distance(
                                mu_gen, sigma_gen, mu_test, sigma_test)
                        print("avg: ", np.mean(fids), " max: ", np.max(fids),
                              " min: ", np.min(fids))
Example #5
def run(rank, size):
    global fl_round
    global rat_per_class
    NUM_CLASSES = 200 if opt.model == 'imagenet' else 10
    criterion = torch.nn.BCELoss()
    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(opt.batch_size, opt.latent_dim, 1, 1)
    if cuda:
        fixed_noise = fixed_noise.cuda()

    # Initialize generator and discriminator
    generator = Generator(1)
    generator.apply(weights_init)
    discriminator = Discriminator(1)
    discriminator.apply(weights_init)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        criterion.cuda()

    # Configure data loader
#DIST
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size, size - 1,
                             size, rank, opt.iid, 1)
    train_set, _ = manager.get_train_set(opt.magic_num)

    lbl_count = [0 for _ in range(NUM_CLASSES)]
    all_labels = []
    for i, (imgs, lbls) in enumerate(train_set):
        for lbl in lbls:
            if lbl.item() not in all_labels:
                all_labels.append(lbl.item())
            lbl_count[lbl.item()] += 1
    workers_classes = gather_lbl_count(lbl_count)
    num_per_class = [500 for _ in range(NUM_CLASSES)]
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]
    #Calculating entropy at this worker
    ent = stats.entropy(np.array(lbl_count) / sum(lbl_count), rat_per_class)

    #Now, initializing all groups for the whole training process
    print("Rank {} Start init groups".format(rank))
    sys.stdout.flush()
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))
    #Calculating entropy of each worker (on the server side) based on these frequencies....
    if rank == 0 and opt.weight_avg:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]
#        print("Entropies are: ", entropies)

# Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    #For FID calculations
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]
            test_labels = t[1]

    # ----------
    #  Training
    # ----------
    #DIST
    elapsed_time = time()
    weak_workers = []
    if weak_percent > 0.0:
        weak_workers = [i for i in range(1, size, round(1 / weak_percent))]
    print("Number of simulated weak workers: ", len(weak_workers))
    num_batches = 0  #This variable acts as a global state variable to sync. between workers and the server
    done_round = True
    group = None
    #The following hack (4 lines) makes training actually run for the number of epochs the user is aiming for: because of the skewness of the data, the actual number of epochs that would run could be smaller than the user estimates. These few lines compensate for that.
    est_len = 1000000 // (
        size * opt.batch_size
    )  #The estimated number of iterations per epoch is dataset_size / (size * batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(train_set):
            #DIST
            if done_round:  #This means that a new round should start....done by sampling a few workers and giving them the latest version of the model(s)
                #First step: Choose a group of nodes to do computations in this round....
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
                if rank in g:
                    broadcast_model(generator, group)
                    broadcast_model(discriminator, group)
                    done_round = False
                else:  #This node is not chosen in the current group....no work for this node in this round....just continue and wait for a new announcement from the server
                    done_round = True
                    num_batches = num_batches + opt.local_steps  #Advance the pointer for workers that will not work this round
                    continue
            # Adversarial ground truths
            real_imgs = Variable(imgs.type(Tensor))
            valid = Variable(Tensor(real_imgs.size()[0], 1, 1, 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(real_imgs.size()[0], 1, 1, 1).fill_(0.0),
                            requires_grad=False)
            num_batches += 1

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = torch.randn(real_imgs.size()[0], opt.latent_dim, 1, 1)
            if cuda:
                z = z.cuda()
            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            d_gen = discriminator(gen_imgs)
            g_loss = criterion(d_gen, valid)
            g_loss.backward()

            #DIST
            #Averaging step.......added because of distributed setup now!
            local_steps = opt.local_steps
            if rank in weak_workers:
                local_steps = int(opt.local_steps / 2)
            if num_batches % local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    #This is a weighting scheme using the entropies based on the frequency of samples of each class at each worker
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    if rank == 0:
                        weights = [entropies[int(wrk)] for wrk in cur_gp]
                    else:  #dummy else
                        weights = [1.0 / len(cur_gp) for _ in cur_gp]
                else:  #Uniform weights, so that `weights` is defined even when weight_avg is off
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    weights = [1.0 / len(cur_gp) for _ in cur_gp]
                #This weighting is orthogonal to the KL-weighting scheme
                average_models(generator, group, choose_r0, weights)
                done_round = True
            if rank == 0 and not choose_r0:
                g_p = generator.parameters()
                for param in generator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())

            optimizer_G.step()
            if rank == 0 and not choose_r0:
                for o, n in zip(g_p, generator.parameters()):
                    if not torch.eq(o, n).all():
                        print(
                            "Generator updated while it should not have been!!!! error here......."
                        )

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = criterion(discriminator(real_imgs), valid)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
            d_loss.backward()

            #DIST
            #Averaging step.......added because of distributed setup now!
            if num_batches % local_steps == 0 and num_batches > 0:
                #In the new version, we apply weights anyway, to account for weak workers and not only for KL-divergence
                average_models(discriminator, group, choose_r0, weights)
                done_round = True

            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_D.step()

            #Print stats and generate images only if this is the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(), time() - elapsed_time),
                    end=' ' if epoch != 0 else '\n')

                # Evaluation step => output images and calculate FID
                if batches_done % opt.sample_interval == 0 and batches_done != 0:
                    fid_z = torch.randn(64, opt.latent_dim, 1, 1)
                    if cuda:
                        fid_z = fid_z.cuda()
                    del gen_imgs
                    gen_imgs = generator(fid_z)
                    mu_gen, sigma_gen = calculate_activation_statistics(
                        gen_imgs, fic_model)
                    mu_test, sigma_test = calculate_activation_statistics(
                        test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()
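
# --- Illustrative sketch (not part of the original snippet) ---
# calculate_frechet_distance() above is assumed to implement the usual FID
# between the two Gaussians fitted to Inception activations:
#     FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
# A minimal sketch with numpy/scipy; the project's implementation may add more
# numerical safeguards:
import numpy as np
from scipy import linalg

def frechet_distance_sketch(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Nudge the diagonals when the product is close to singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)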
Example #6
def run(rank, size):
    global fl_round
    global rat_per_class
    # Minimizes MSE
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)
    restart_count = 0
    epch = 0
    el_time = 0
    fl_rd = 0
    if (os.path.isfile(cp_path + "/checkpoint")):
        print("Conotinuing traing from a checkpoint")
        checkpoint = torch.load(cp_path + "/checkpoint")
        generator.load_state_dict(checkpoint['gen'])
        discriminator.load_state_dict(checkpoint['disc'])
        epch = checkpoint['epoch']
        el_time = checkpoint['time']
        fl_round = checkpoint['fl_round']
        restart_count = restart_count + 1
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Configure data loader
    same_data = False  #set this flag to True if all devices are required to have the same data (not realistic; only for simulation)
    if same_data:
        os.makedirs("../data/mnist", exist_ok=True)
        train_set = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data/mnist",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(opt.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ]),
            ),
            batch_size=opt.batch_size,
            shuffle=True,
        )
    else:
        manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                                 size - 1, size, rank, opt.iid, num_servers)
        train_set, _ = manager.get_train_set(opt.magic_num)

    lbl_count = [0 for _ in range(10)]
    for i, (imgs, lbls) in enumerate(train_set):
        for lbl in lbls:
            lbl_count[lbl.item()] += 1
    #This piece of info should be gathered at the server (to make an informed decision about sampling)
    workers_classes = gather_lbl_count(lbl_count)
    if rank == 0:
        print(workers_classes)
    num_per_class = [
        5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949
    ]
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]
    #Calculating entropy at this worker

    #Now, initializing all groups for the whole training process
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))
    #Calculating entropy of each worker (on the server side) based on these frequencies....
    if rank == 0:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]
#        print("Entropies are: ", entropies)

# Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    print("cuda is there? ", cuda)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    #For FID calculations
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]
            test_labels = t[1]
        grouped_test_imgages = [[] for i in range(10)]
        for i, img in enumerate(test_imgs):
            grouped_test_imgages[test_labels[i]].append(img)
        for i, arr in enumerate(grouped_test_imgages):
            grouped_test_imgages[i] = torch.stack(arr)
        print("just before training....server is talking")
        sys.stdout.flush()

    # ----------
    #  Training
    # ----------
    #DIST
    elapsed_time = time.time()
    num_batches = 0  #This variable acts as a global state variable to sync. between workers and the server
    done_round = True
    group = None
    #The following hack (4 lines) makes training actually run for the number of epochs the user is aiming for: because of the skewness of the data, the actual number of epochs that would run could be smaller than the user estimates. These few lines compensate for that.
    est_len = 50000 // (
        size * opt.batch_size
    )  #Given a dataset of 50,000 images, the estimated number of iterations per epoch is 50000 / (size * batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))
    if rank == 0:
        print("Starting training...")
        sys.stdout.flush()
    epoch = 0
    while epoch < opt.n_epochs:
        if epoch == 0:
            epoch = epch  #Load the saved one in the checkpoint
        for i, (imgs, _) in enumerate(train_set):
            #DIST
            if done_round:  #This means that a new round should start....done by sampling a few workers and giving them the latest version of the model(s)
                #In the beginning of each round, the primary server broadcasts the model to all other servers so that the model is kept safe in case of a crash failure
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
                if rank in g:
                    broadcast_model(generator, group, elapsed_time)
                    broadcast_model(discriminator, group, elapsed_time)
                    done_round = False
                else:  #This node is not chosen in the current group....no work for this node in this round....just continue and wait for a new announcement from the server
                    done_round = True
                    num_batches = num_batches + opt.local_steps  #Advance the pointer for workers that will not work this round
                    continue
# uncomment the following lines to simulate/test a server crash
#            if rank == 0:
#                if time.time() - elapsed_time > 500 and restart_count == 0:
#                    print("Crashing the server, first time..........................................")
#                    time.sleepp(1000)				#What about a software bug here ;)
            num_batches += 1
            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0),
                            requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            d_gen = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_gen, valid)

            g_loss.backward()

            #DIST
            #            g_avg_t = time()
            #Averaging step.......added because of distributed setup now!
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    #This is a weighting scheme using the entropies based on the frequency of samples of each class at each worker
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    if rank == 0:
                        weights = [entropies[int(wrk)] for wrk in cur_gp]
                    else:  #dummy else
                        weights = [1.0 / len(cur_gp) for _ in cur_gp]
                    average_models(
                        generator,
                        group,
                        choose_r0,
                        weights,
                        elapsed_time=elapsed_time
                    )  #Experiments show that doing this is bad anyway!
                else:
                    average_models(generator,
                                   group,
                                   choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True

            if rank == 0 and not choose_r0:
                g_p = generator.parameters()
                for param in generator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())

            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),
                                         fake)
            d_loss = 0.5 * (real_loss + fake_loss)

            d_loss.backward()

            #DIST
            #Averaging step.......added because of distributed setup now!
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    average_models(discriminator,
                                   group,
                                   choose_r0,
                                   weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(discriminator,
                                   group,
                                   choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True
            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_D.step()

            #Print stats and generate images only if this is the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(),
                       time.time() - elapsed_time + el_time),
                    end=' ' if epoch != 0 else '\n')

                # Evaluation step => output images and calculate FID
                if batches_done % opt.sample_interval == 0 and batches_done != 0:
                    fid_z = Variable(
                        Tensor(
                            np.random.normal(0, 1,
                                             (opt.fid_batch, opt.latent_dim))))
                    del gen_imgs
                    gen_imgs = generator(fid_z)
                    mu_gen, sigma_gen = calculate_activation_statistics(
                        gen_imgs, fic_model)
                    mu_test, sigma_test = calculate_activation_statistics(
                        test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()
                    #For fault tolerance
                    print("saving checkpoint")
                    state = {
                        'disc': discriminator.state_dict(),
                        'gen': generator.state_dict(),
                        'epoch': epoch,
                        'time': time.time() - elapsed_time + el_time,
                        'fl_round': fl_round
                    }
                    torch.save(state, cp_path + "/checkpoint")
        epoch = epoch + 1
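
# --- Illustrative sketch (not part of the original snippet) ---
# gather_lbl_count() is called in several snippets above but not defined here.
# A hedged guess at its behavior: every rank shares its per-class sample counts
# and receives the counts of all ranks, e.g. via all_gather (the real helper may
# instead gather only at the server):
import torch
import torch.distributed as dist

def gather_lbl_count_sketch(lbl_count):
    local = torch.tensor(lbl_count, dtype=torch.float32)
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    return [t.tolist() for t in gathered]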
Example #7
    with open(os.path.join('experimentation',  'cinvestav_testbed_experiment_results_' + str(seed)), 'rb') as f:
        results = cPickle.load(f)

    plot_cost(
        results=results,
        data_name='cost_train',
        plot_label='Cost on train phase')
    plot_cost(
        results=results,
        data_name='cost_valid',
        plot_label='Cost on valid phase')
    plot_cost(
        results=results,
        data_name='cost_test',
        plot_label='Cost on test phase')
    plt.show()

    """
    seed = 50
    dataset, result = DatasetManager.read_dataset2('test_cleaned_dataset.csv', shared=True, seed=seed)
    with open(os.path.join('trained_models',  'Logistic Regressionbrandeis_university.save'), 'rb') as f:
        model = cPickle.load(f)

    predicted_values = model.predict(dataset)
    get_metrics(
        test_set_y=result,
        predicted_values=predicted_values,
        model_name='Logistic Regression'
    )
    """

Example #8
def run(rank, size):
    """ Distributed Synchronous SGD main function
	Args
	rank	Rank of the current process
	size	Total size of the world (num_workers + num_servers)
    """

    # Preparing hyper-parameters
    torch.manual_seed(1234)
    manager = DatasetManager(dataset, minibatch, num_workers, size, rank)
    train_set, bsz = manager.get_train_set()
    test_set = manager.get_test_set()
    if torch.cuda.device_count() > 0: # and rank >= num_ps:
      device = torch.device("cuda") #(rank-num_ps)%torch.cuda.device_count()))
    else:
      device = torch.device("cpu:0")
      print("CPU WARNING =====================================================================")

    print("Rank {} -> Device {}".format(rank, device))
    model = select_model(model_n, device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd) #for Cifar10, 0.001 and 0.9, MNIST: 0.01 and 0.5
    if model_n == 'resnet50':
      scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50], gamma=0.1)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    loss_fn = select_loss(loss_fn_n)
    g_l = [i for i in range(size)]
    world = dist.new_group(g_l)
    init_groups()
    #If PS are Byzantine, some subgroups of the world are required.....will be initialized as follows...
    print("-------------------------------- Rank {} have already done init groups...".format(rank))
    sys.stdout.flush()

    start_time = time.time()
    # Training loop
    print("One epoch has how many iterations: ", len(train_set))
    for epoch in range(epochs):
        epoch_loss = 0.0
        if model_n == 'resnet50':
          scheduler.step()
        model.train()
        for index, (data, target) in enumerate(train_set):
            if log:
              print("Rank {} Starting iteration {}".format(rank, index))
            train_time = time.time()
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)
            if rank >= num_ps:
                loss = loss_fn(output, target)
                loss.backward()
                epoch_loss += loss.item()
            if bench:
              print("Rank {} Train time {} ".format(rank, time.time() - train_time))
            if log:
              print("Rank {} Loop iteration {} Loss {}".format(rank,index, epoch_loss))
              sys.stdout.flush()
            reduce_time = time.time()
            reduce_gradients(model,rank, device, index)
            if bench:
              print("Rank {}, reduce time {} ".format(rank, time.time() - reduce_time))
            dist.barrier(world)
            optimizer.step()

        # Testing
        if rank < num_ps:
          test_time = time.time()
          acc = get_accuracy(model, test_set, device)
          print('Rank ', rank, ' epoch: ', epoch, ' acc: ', acc, "time: ", time.time() - start_time)
          print("Rank {}, test time {} ".format(rank, time.time() - test_time))
        else:
          print('Rank ', rank, 'epoch: ', epoch, 'loss: ', epoch_loss, "time: ", time.time() - start_time)
        sys.stdout.flush()
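
# --- Illustrative sketch (not part of the original snippet) ---
# reduce_gradients() above is assumed to average the workers' gradients before
# the optimizer step. A minimal sketch with torch.distributed; the real
# implementation also takes rank/device/index and handles Byzantine parameter
# servers, which is omitted here:
import torch.distributed as dist

def reduce_gradients_sketch(model, world_size):
    for param in model.parameters():
        if param.grad is None:
            continue
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size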
def theano_experiments():
    dataset_name = 'cinvestav_labeled.csv'
    seed = 5
    rgn = numpy.random.RandomState(seed)

    datasets = DatasetManager.read_dataset(
        dataset_name=os.path.join(os.path.dirname(__file__), 'dataset', 'meters', dataset_name),
        shared=True,
        seed=seed,
        expected_output=['result_x', 'result_y'],
        skipped_columns=[],
        label_encoding_columns_name=[],
        sklearn_preprocessing=preprocessing.StandardScaler(with_mean=True, with_std=True),
        sklearn_feature_selection=feature_selection.VarianceThreshold(),
        train_ratio=.8,
        test_ratio=0,
        valid_ratio=.2
    )

    test_set = DatasetManager.get_prediction_set(
        dataset_name=os.path.join(os.path.dirname(__file__), 'dataset', 'meters', 'cinvestav_labeled_test.csv'),
        expected_output=['result_x', 'result_y'],
        label_encoding_columns_name=[],
        skipped_columns=[],
        sklearn_preprocessing=datasets['sklearn_preprocessing'],
        sklearn_feature_selection=datasets['sklearn_feature_selection'],
        shared=True
    )

    dataset_unlabeled = DatasetManager.get_prediction_set(
        dataset_name=os.path.join(os.path.dirname(__file__), "dataset", 'cinvestav_unlabeled.csv'),
        skipped_columns=['result_x', 'result_y'],
        label_encoding_columns_name=[],
        sklearn_preprocessing=datasets['sklearn_preprocessing'],
        sklearn_feature_selection=datasets['sklearn_feature_selection'],
        shared=True
    )

    datasets['test_set'] = test_set
    datasets['dataset_unlabeled'] = dataset_unlabeled
    datasets['prediction_set'] = datasets['test_set'][0].get_value()
    train_set_x, train_set_y = datasets['train_set']

    n_in = train_set_x.get_value().shape[1]
    n_out = train_set_y.get_value().shape[1]

    dnn_tanh_models = get_neural_networks(
        n_in,
        n_out,
        rgn,
        activation_function=T.tanh  # T.nnet.relu
    )

    dnn_relu_models = get_neural_networks(
        n_in,
        n_out,
        rgn,
        activation_function=T.nnet.relu
    )

    dnn_sigmoid_models = get_neural_networks(
        n_in,
        n_out,
        rgn,
        activation_function=T.nnet.sigmoid
    )

    dbn_models = get_dbn(
        n_in,
        n_out,
        rgn,
        gaussian=False
    )

    gdbn_models = get_dbn(
        n_in,
        n_out,
        rgn,
        gaussian=True
    )

    models = []
    models.extend(dnn_relu_models)
    models.extend(dnn_sigmoid_models)
    models.extend(dnn_tanh_models)
    models.extend(gdbn_models)
    models.extend(dbn_models)

    params = {
        'learning_rate': .01,
        'annealing_learning_rate': .99999,
        'l1_learning_rate': 0.01,
        'l2_learning_rate': 0.001,
        'n_epochs': 2000,
        'batch_size': 20,
        'pre_training_epochs': 50,
        'pre_train_lr': 0.01,
        'k': 1,
        'datasets': datasets,
        'noise_rate': .1,
        'dropout_rate': None
    }

    run_theano_experiments(
        models=models,
        seed=seed,
        params=params,
        experiment_name='all_models_with_noise_without_dropout',
        task_type='regression'
    )
def sklearn_experiments():
    dataset_name = 'cinvestav_labeled.csv'
    seed = 5
    datasets = DatasetManager.read_dataset(
        dataset_name=os.path.join(os.path.dirname(__file__), "dataset", dataset_name),
        shared=False,
        seed=seed,
        expected_output=['result_x', 'result_y'],
        skipped_columns=[],
        label_encoding_columns_name=[],
        sklearn_preprocessing=preprocessing.StandardScaler(with_mean=True, with_std=True),
        sklearn_feature_selection=feature_selection.VarianceThreshold(),
        train_ratio=1,
        test_ratio=0,
        valid_ratio=0
    )

    test_set = DatasetManager.get_prediction_set(
        dataset_name=os.path.join(os.path.dirname(__file__), "dataset", 'cinvestav_labeled_test.csv'),
        expected_output=['result_x', 'result_y'],
        label_encoding_columns_name=[],
        skipped_columns=[],
        shared=False,
        sklearn_preprocessing=datasets['sklearn_preprocessing'],
        sklearn_feature_selection=datasets['sklearn_feature_selection'],
    )

    datasets['test_set'] = test_set
    datasets['prediction_set'] = datasets['test_set'][0]

    train_set_x, train_set_y = datasets['train_set']
    n_in = train_set_x.shape[1]
    n_out = train_set_y.shape[1]

    # Create Radial Basis Networks
    rbf = RBF(
        input_length=n_in,
        hidden_length=500,
        out_lenght=n_out
    )

    # Create KNN
    knn = SklearnNetwork(
        sklearn_model=KNeighborsRegressor(n_neighbors=10),
        num_output=n_out
    )

    # Create ada boosting
    ada_boosting = SklearnNetwork(
        sklearn_model=GradientBoostingRegressor(n_estimators=1000, learning_rate=.1, max_depth=5, loss='ls'),
        num_output=n_out
    )

    models = [
        ('Ada Boosting', ada_boosting),
        ('Radar', knn),
        ('cRBF', rbf)
    ]

    params = {
        'datasets': datasets
    }
    run_experiments_sklearn(
        models=models,
        seed=seed,
        params=params,
        experiment_name='traditional_algorithms',
        task_type='regression'
    )
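
# --- Illustrative sketch (not part of the original snippet) ---
# SklearnNetwork above adapts an arbitrary scikit-learn regressor to the same
# interface as the project's other models. A hedged sketch of such a wrapper;
# the real class likely exposes more of that interface:
class SklearnNetworkSketch:
    def __init__(self, sklearn_model, num_output):
        self.model = sklearn_model
        self.num_output = num_output

    def fit(self, train_set_x, train_set_y):
        self.model.fit(train_set_x, train_set_y)
        return self

    def predict(self, dataset):
        return self.model.predict(dataset)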
Example #11
def run(rank, size):
    global fl_round
    global rat_per_class
    # !!! Minimizes MSE instead of BCE
    adversarial_loss = torch.nn.MSELoss()
#    adversarial_loss = torch.nn.BCELoss() #torch.nn.MSELoss #nn.BCELoss()
    # Initialize generator and discriminator
    generator = Generator() #(1)
    discriminator = Discriminator() #(1)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
#DIST (fix the path of data)
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size, size-1, size, rank, opt.iid)
    train_set, _ = manager.get_train_set(opt.max_samples)
    init_groups(size)
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    #For FID calculations
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        for i,t in enumerate(test_set):
            test_imgs = t[0].cuda()
            test_labels = t[1]

    # ----------
    #  Training
    # ----------
    #DIST
    elapsed_time = time()
    num_batches=0		#This variable acts as a global state variable to sync. between workers and the server
    done_round = True
    group = None
    #The following hack (4 lines) makes training actually run for the number of epochs the user is aiming for: because of the skewness of the data, the actual number of epochs that would run could be smaller than the user estimates. These few lines compensate for that.
    est_len = 50000 // (size * opt.batch_size)  #Given a dataset of 50,000 images, the estimated number of iterations per epoch is 50000 / (size * batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len/act_len))
    imgs = []
#    print("Rank {}  just before the training loop....".format(rank))
    for i, (tmps,_) in enumerate(train_set):  #hack to grab only the first batch
        imgs=tmps
        break
    for epoch in range(opt.n_epochs):
        broadcast_model(generator, elapsed_time=elapsed_time)
        fl_round+=1
        num_batches+=1
            # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        #HINT: training the generator is not required on the server, yet I am doing it only because PyTorch requires it. It does not affect the runtime anyway
            # -----------------
            #  Train Generator
            # -----------------
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
#        z = torch.randn(imgs.shape[0], opt.latent_dim, 1, 1).cuda()
        temp = generator(z)
        if rank == 0:		#MD-GAN trains the generator only on the server
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
#            z = torch.randn(imgs.shape[0], opt.latent_dim, 1, 1).cuda()
            # Generate a batch of images
            X_g = generator(z)
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
#            z = torch.randn(imgs.shape[0], opt.latent_dim, 1, 1).cuda()
            # Generate a batch of images
            X_d = generator(z)
            for n in range(size-1):
                dist.broadcast(tensor=X_g, src=0, group=all_groups[n])
                dist.broadcast(tensor=X_d, src=0,group=all_groups[n])

        else: #First, workers receive generated batches from the server
            X_g = torch.zeros(temp.size())
            X_d = torch.zeros(temp.size())
            dist.broadcast(tensor=X_g, src=0, group=all_groups[rank-1])
            dist.broadcast(tensor=X_d, src=0, group=all_groups[rank-1])
            if cuda:
                X_g = X_g.cuda()
                X_d = X_d.cuda()

            # Loss measures generator's ability to fool the discriminator
        if rank == 0:
            d_gen = discriminator(temp)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

#            disc_t = time()
        if rank != 0:
            L = 12								#This is a parameter by MD-GAN. A worker should only do L iterations.
            for iter, (imgs_t, _) in enumerate(train_set):
                real_imgs = Variable(imgs_t.type(Tensor))
                if real_imgs.size()[0] != opt.batch_size:			#To avoid mismatch problems
                    continue
                optimizer_D.zero_grad()

                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss = adversarial_loss(discriminator(X_d.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)
                d_loss.backward()
                optimizer_D.step()
#                print("process {} iter {}".format(rank,iter))
                if iter == L-1:
                    break

            optimizer_G.zero_grad()
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
#            z = torch.randn(imgs.shape[0], opt.latent_dim, 1, 1).cuda()
            X_g = generator(z)
            g_loss = adversarial_loss(discriminator(X_g), valid)
            g_loss.backward()
            optimizer_G.step()
        average_models(generator, elapsed_time=elapsed_time)
        del X_g
        del X_d
            #Print stats and generate images only if this is the server
        batches_done = fl_round #epoch * len(train_set) + i
        if rank == 0 and fl_round%20 == 0:
            print(
                "Rank %d [Epoch %d/%d] [Batch %d/%d] time %f"
                % (rank, epoch, opt.n_epochs, i, len(train_set), time() - elapsed_time), 
                end = ' ' if epoch != 0 else '\n'
            )
#                sys.stdout.flush()

                # Evaluation step => output images and calculate FID
#                if batches_done % opt.sample_interval == 0 and batches_done != 0:
#                    pathname = os.path.abspath(os.path.dirname(sys.argv[0]))
#                    save_image(gen_imgs.data[:25], pathname+"/images-dist-s{}-w{}/{}-{}.png".format(opt.sample, opt.weight_avg, rank,batches_done), nrow=5, normalize=True)
#                    print("=====Calculating FID for round {}======".format(fl_round))
            fid_z = Variable(Tensor(np.random.normal(0, 1, (opt.fid_batch, opt.latent_dim))))
            gen_imgs = generator(fid_z)
            mu_gen, sigma_gen = calculate_activation_statistics(gen_imgs, fic_model)
            mu_test, sigma_test = calculate_activation_statistics(test_imgs[:opt.fid_batch], fic_model)
            fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test, sigma_test)
#            fid = 3000
            print("FL-round {} FID Score: {}".format(fl_round, fid))
            sys.stdout.flush()
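
# --- Illustrative sketch (not part of the original snippet) ---
# calculate_activation_statistics() above is assumed to push the images through
# the InceptionV3 feature extractor and return the mean and covariance of the
# resulting activations; the exact feature layer, pooling, and pre-processing
# depend on the project's InceptionV3 wrapper:
import numpy as np
import torch

def activation_statistics_sketch(images, inception_model):
    inception_model.eval()
    with torch.no_grad():
        features = inception_model(images)[0]  # assumed: the wrapper returns a list of feature maps
        features = features.view(features.size(0), -1).cpu().numpy()
    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma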