def run(rank, size):
    """Run the MD-GAN-style training loop for one process.

    Rank 0 acts as the server: it generates two batches (``X_g`` and ``X_d``)
    and broadcasts them to every worker, and it periodically reports the FID
    score. Every other rank trains its local discriminator against ``X_d``
    and then takes one generator step; generators are averaged at the end of
    each round through ``average_models``.

    Args:
        rank: Rank of the current process (0 is the server).
        size: Total number of processes (server + workers).
    """
    global fl_round

    # Least-squares GAN objective (MSE).
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader.
    # The manager is created unconditionally because rank 0 needs it for the
    # test set even when every device shares the same training data.
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                             size - 1, size, rank, opt.iid, 1)
    same_data = False  # set to True if all devices are required to hold the same data
    if same_data:
        os.makedirs("../data/mnist", exist_ok=True)
        train_set = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data/mnist",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(opt.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ]),
            ),
            batch_size=opt.batch_size,
            shuffle=True,
        )
    else:
        train_set, _ = manager.get_train_set(opt.magic_num)
    init_groups(size)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # For FID calculations (server only)
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        # Keeps the LAST batch yielded by the test loader as the reference.
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]

    # ----------
    #  Training
    # ----------
    elapsed_time = time()

    # The following hack makes the run last as many iterations as the user
    # is aiming for: with skewed data the local loader can be shorter than
    # the estimate, so the epoch count is scaled up accordingly.
    # Estimate assumes a 50,000-image dataset split over `size` processes.
    est_len = 50000 // (size * opt.batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))

    # Peek at the first training batch; only its shape is used below.
    imgs = []
    for i, (tmps, _) in enumerate(train_set):
        imgs = tmps
        break

    for epoch in range(opt.n_epochs):
        broadcast_model(generator, elapsed_time=elapsed_time)
        fl_round += 1

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # HINT: training the generator is not required on the server, yet
        # PyTorch requires it. It does not affect the runtime anyway.
        # -----------------
        #  Train Generator
        # -----------------
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        temp = generator(z)
        if rank == 0:
            # MD-GAN trains the generator only on the server
            optimizer_G.zero_grad()
            # Sample noise and generate the two batches sent to the workers
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            X_g = generator(z)
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            X_d = generator(z)
            for n in range(size - 1):
                dist.broadcast(tensor=X_g, src=0, group=all_groups[n])
                dist.broadcast(tensor=X_d, src=0, group=all_groups[n])
        else:
            # First, workers receive the batches generated by the server
            X_g = torch.zeros(temp.size())
            X_d = torch.zeros(temp.size())
            dist.broadcast(tensor=X_g, src=0, group=all_groups[rank - 1])
            dist.broadcast(tensor=X_d, src=0, group=all_groups[rank - 1])
            if cuda:
                X_g = X_g.cuda()
                X_d = X_d.cuda()

        # Loss measures generator's ability to fool the discriminator
        if rank == 0:
            d_gen = discriminator(temp)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()
            optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        if rank != 0:
            disc_steps = 12  # number of local discriminator batches per round
            for step, (imgs_t, _) in enumerate(train_set):
                real_imgs = Variable(imgs_t.type(Tensor))
                if real_imgs.size()[0] != opt.batch_size:
                    # To avoid shape mismatch with `valid` / `fake`
                    continue
                optimizer_D.zero_grad()
                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss = adversarial_loss(discriminator(X_d.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)
                d_loss.backward()
                optimizer_D.step()
                if step == disc_steps - 1:
                    break

            # NOTE(review): nesting reconstructed from a whitespace-collapsed
            # source; the worker-side generator step is assumed to belong to
            # the `rank != 0` branch - confirm against the original file.
            optimizer_G.zero_grad()
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            X_g = generator(z)
            g_loss = adversarial_loss(discriminator(X_g), valid)
            g_loss.backward()
            optimizer_G.step()

        average_models(generator, elapsed_time=elapsed_time)
        del X_g
        del X_d

        # Print stats and compute FID only on the server
        if rank == 0 and fl_round % 20 == 0:
            print("Rank %d [Epoch %d/%d] [Batch %d/%d] time %f" %
                  (rank, epoch, opt.n_epochs, i, len(train_set),
                   time() - elapsed_time),
                  end=' ' if epoch != 0 else '\n')
            fid_z = Variable(Tensor(np.random.normal(0, 1, (opt.fid_batch, opt.latent_dim))))
            # Evaluation only: no autograd graph is needed here.
            with torch.no_grad():
                gen_imgs = generator(fid_z)
                mu_gen, sigma_gen = calculate_activation_statistics(
                    gen_imgs, fic_model)
                mu_test, sigma_test = calculate_activation_statistics(
                    test_imgs[:opt.fid_batch], fic_model)
            fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test, sigma_test)
            print("FL-round {} FID Score: {}".format(fl_round, fid))
            sys.stdout.flush()
def run(rank, size):
    """Run the federated LSGAN training loop for one process.

    Each round a group of ranks is selected from the module-level
    ``all_groups`` / ``all_groups_np`` tables. Selected ranks receive the
    current generator and discriminator, train locally, and average the
    models every ``opt.local_steps`` batches. Rank 0 (the server) also
    prints losses and the FID score every ``opt.sample_interval`` batches.

    Args:
        rank: Rank of the current process (0 is the server).
        size: Total number of processes (server + workers).
    """
    global fl_round
    global rat_per_class

    # !!! Minimizes MSE instead of BCE
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                             size - 1, size, rank, opt.iid)
    train_set, _ = manager.get_train_set(opt.max_samples)

    # Per-class sample counts of the local data; gathered at the server so
    # it can make informed sampling decisions.
    lbl_count = [0 for _ in range(10)]
    for _, lbls in train_set:
        for lbl in lbls:
            lbl_count[lbl.item()] += 1
    workers_classes = gather_lbl_count(lbl_count)
    if rank == 0:
        print(workers_classes)

    # Aggregate number of samples per class, entered manually here.
    num_per_class = [
        5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949
    ]
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]

    # Initialize all groups for the whole training process
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))

    # Entropy-based weight of each worker (server side), derived from the
    # class frequencies and scaled by the worker's share of the samples.
    if rank == 0:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]
        print("Entropies are: ", entropies)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # For FID calculations (server only)
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        # Keeps the LAST batch yielded by the test loader as the reference.
        for i, t in enumerate(test_set):
            # Only move to the GPU when one is in use (was unconditional).
            test_imgs = t[0].cuda() if cuda else t[0]
            test_labels = t[1]
        # Test images grouped by label; only consumed by the disabled
        # per-class FID branch further down.
        grouped_test_imgages = [[] for _ in range(10)]
        for idx, img in enumerate(test_imgs):
            grouped_test_imgages[test_labels[idx]].append(img)
        for idx, arr in enumerate(grouped_test_imgages):
            grouped_test_imgages[idx] = torch.stack(arr)

    # ----------
    #  Training
    # ----------
    elapsed_time = time()
    num_batches = 0
    # Global state flag used to sync between workers and the server:
    # True means a new round must start on the next batch.
    done_round = True
    group = None

    # The following hack makes the run last as many iterations as the user
    # is aiming for: with skewed data the local loader can be shorter than
    # the estimate, so the epoch count is scaled up accordingly.
    # Estimate assumes a 50,000-image dataset split over `size` processes.
    est_len = 50000 // (size * opt.batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(train_set):
            if done_round:
                # A new round starts: pick this round's group and hand its
                # members the latest version of the models.
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
                if rank in g:
                    broadcast_model(generator, group, elapsed_time)
                    broadcast_model(discriminator, group, elapsed_time)
                    done_round = False
                else:
                    # Not chosen this round: no work, wait for the next
                    # announcement from the server.
                    done_round = True
                    # Advance the pointer for workers that sit this round out
                    num_batches = num_batches + opt.local_steps
                    continue
            num_batches += 1

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            # Generate a batch of images
            gen_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            d_gen = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()

            # Averaging step, added because of the distributed setup.
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    # Weighting scheme using the entropies based on the
                    # frequency of samples of each class at each worker.
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    if rank == 0:
                        weights = [entropies[int(wrk)] for wrk in cur_gp]
                    else:
                        # Dummy uniform weights on the workers
                        weights = [1.0 / len(cur_gp) for _ in cur_gp]
                    # Experiments show that doing this is bad anyway!
                    average_models(generator, group, choose_r0, weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(generator, group, choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True

            # NOTE(review): nesting reconstructed from a whitespace-collapsed
            # source; the gradient zeroing is assumed to run on every batch,
            # not only on averaging steps - confirm against the original file.
            if rank == 0 and not choose_r0:
                # Server not selected to train: zero its gradients. The
                # tensor is placed on the GPU only when one is in use.
                for param in generator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
            d_loss.backward()

            # Averaging step, added because of the distributed setup.
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    # `weights` was computed in the generator averaging step
                    # of this same batch.
                    average_models(discriminator, group, choose_r0, weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(discriminator, group, choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True
            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_D.step()

            # Print stats and compute FID only on the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(), time() - elapsed_time),
                    end=' ' if epoch != 0 else '\n')
                # Evaluation step => calculate FID (skipped on the first batch)
                if batches_done != 0:
                    fid_z = Variable(Tensor(np.random.normal(
                        0, 1, (opt.fid_batch, opt.latent_dim))))
                    del gen_imgs
                    # Evaluation only: no autograd graph is needed here.
                    with torch.no_grad():
                        gen_imgs = generator(fid_z)
                        mu_gen, sigma_gen = calculate_activation_statistics(
                            gen_imgs, fic_model)
                        mu_test, sigma_test = calculate_activation_statistics(
                            test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()
                    if False:  # not opt.iid:  (per-class FID, disabled)
                        cur = 0
                        fids = [0 for _ in range(10)]
                        for cls_idx, gp in enumerate(grouped_test_imgages):
                            mu_gen, sigma_gen = calculate_activation_statistics(
                                gen_imgs[cur:cur + len(gp)], fic_model)
                            cur += len(gp)
                            mu_test, sigma_test = calculate_activation_statistics(
                                gp, fic_model)
                            fids[cls_idx] = calculate_frechet_distance(
                                mu_gen, sigma_gen, mu_test, sigma_test)
                        print("avg: ", np.mean(fids), " max: ", np.max(fids),
                              " min: ", np.min(fids))
def run(rank, size):
    """Run the federated DCGAN training loop (with simulated weak workers).

    Each round a group of ranks is selected from the module-level
    ``all_groups`` / ``all_groups_np`` tables. Selected ranks receive the
    current models, train locally with a BCE objective, and average the
    models every ``local_steps`` batches (halved for simulated weak
    workers). Rank 0 (the server) also prints losses and the FID score every
    ``opt.sample_interval`` batches.

    Args:
        rank: Rank of the current process (0 is the server).
        size: Total number of processes (server + workers).
    """
    global fl_round
    global rat_per_class

    NUM_CLASSES = 200 if opt.model == 'imagenet' else 10
    criterion = torch.nn.BCELoss()

    # Batch of latent vectors meant for visualizing the generator's
    # progression. It is not read again in this function; the draw is kept
    # so the global RNG stream (and hence weight init) is unchanged.
    fixed_noise = torch.randn(opt.batch_size, opt.latent_dim, 1, 1)
    if cuda:
        fixed_noise = fixed_noise.cuda()

    # Initialize generator and discriminator
    generator = Generator(1)
    generator.apply(weights_init)
    discriminator = Discriminator(1)
    discriminator.apply(weights_init)
    if cuda:
        generator.cuda()
        discriminator.cuda()
        criterion.cuda()

    # Configure data loader
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                             size - 1, size, rank, opt.iid, 1)
    train_set, _ = manager.get_train_set(opt.magic_num)

    # Per-class sample counts of the local data, gathered across ranks.
    lbl_count = [0 for _ in range(NUM_CLASSES)]
    for _, lbls in train_set:
        for lbl in lbls:
            lbl_count[lbl.item()] += 1
    workers_classes = gather_lbl_count(lbl_count)

    # Reference class distribution: 500 samples per class.
    num_per_class = [500 for _ in range(NUM_CLASSES)]
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]

    # Initialize all groups for the whole training process
    print("Rank {} Start init groups".format(rank))
    sys.stdout.flush()
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))

    # Entropy-based weight of each worker (server side), derived from the
    # class frequencies and scaled by the worker's share of the samples.
    if rank == 0 and opt.weight_avg:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # For FID calculations (server only)
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        # Keeps the LAST batch yielded by the test loader as the reference.
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]

    # ----------
    #  Training
    # ----------
    elapsed_time = time()

    # Every k-th worker is simulated as weak, with k = round(1/weak_percent).
    # The stride is clamped to 1 so range() never receives a zero step.
    weak_workers = []
    if weak_percent > 0.0:
        weak_workers = list(range(1, size, max(1, round(1 / weak_percent))))
    print("Number of simulated weak workers: ", len(weak_workers))

    # Weak workers use half the local steps; never fewer than one, otherwise
    # the modulo below would divide by zero when opt.local_steps == 1.
    local_steps = opt.local_steps
    if rank in weak_workers:
        local_steps = max(1, opt.local_steps // 2)

    num_batches = 0
    # Global state flag used to sync between workers and the server:
    # True means a new round must start on the next batch.
    done_round = True
    group = None

    # The following hack makes the run last as many iterations as the user
    # is aiming for: with skewed data the local loader can be shorter than
    # the estimate, so the epoch count is scaled up accordingly.
    # Estimate assumes a 1,000,000-image dataset split over `size` processes.
    est_len = 1000000 // (size * opt.batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(train_set):
            if done_round:
                # A new round starts: pick this round's group and hand its
                # members the latest version of the models.
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
                if rank in g:
                    broadcast_model(generator, group)
                    broadcast_model(discriminator, group)
                    done_round = False
                else:
                    # Not chosen this round: no work, wait for the next
                    # announcement from the server.
                    done_round = True
                    # Advance the pointer for workers that sit this round out
                    num_batches = num_batches + opt.local_steps
                    continue

            # Adversarial ground truths
            real_imgs = Variable(imgs.type(Tensor))
            valid = Variable(Tensor(real_imgs.size()[0], 1, 1, 1).fill_(1.0),
                             requires_grad=False)
            fake = Variable(Tensor(real_imgs.size()[0], 1, 1, 1).fill_(0.0),
                            requires_grad=False)
            num_batches += 1

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = torch.randn(real_imgs.size()[0], opt.latent_dim, 1, 1)
            if cuda:
                z = z.cuda()
            # Generate a batch of images
            gen_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            d_gen = discriminator(gen_imgs)
            g_loss = criterion(d_gen, valid)
            g_loss.backward()

            # Averaging step, added because of the distributed setup.
            averaging_step = num_batches % local_steps == 0 and num_batches > 0
            if averaging_step:
                cur_gp = all_groups_np[fl_round % len(all_groups)]
                if opt.weight_avg and rank == 0:
                    # Entropy-based weights from the class frequencies of
                    # each worker in the current group.
                    weights = [entropies[int(wrk)] for wrk in cur_gp]
                else:
                    # Uniform weights: the dummy value on workers, and the
                    # fallback when entropy weighting is off (`weights` was
                    # previously unbound in that case).
                    # NOTE(review): assumes average_models treats uniform
                    # weights as a plain average - confirm.
                    weights = [1.0 / len(cur_gp) for _ in cur_gp]
                # This weighting is orthogonal to the KL-weighting scheme
                average_models(generator, group, choose_r0, weights)
                done_round = True

            # NOTE(review): nesting reconstructed from a whitespace-collapsed
            # source; the gradient zeroing is assumed to run on every batch,
            # not only on averaging steps - confirm against the original file.
            if rank == 0 and not choose_r0:
                g_p = generator.parameters()
                # Server not selected to train: zero its gradients.
                for param in generator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_G.step()
            if rank == 0 and not choose_r0:
                # NOTE(review): `g_p` yields the very same Parameter objects
                # as generator.parameters(), so each tensor is compared with
                # itself; this can only fire for values that are not equal
                # to themselves (e.g. NaN). A real before/after check would
                # need a cloned snapshot taken before the step.
                for o, n in zip(g_p, generator.parameters()):
                    if not torch.eq(o, n).all():
                        print(
                            "Generator updated while it should not have been!!!! error here......."
                        )

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = criterion(discriminator(real_imgs), valid)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
            d_loss.backward()

            # Averaging step, added because of the distributed setup.
            if averaging_step:
                # Weights are applied anyway, to account for weak workers
                # and not only for KL-divergence; they were computed in the
                # generator averaging step of this same batch.
                average_models(discriminator, group, choose_r0, weights)
                done_round = True
            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_D.step()

            # Print stats and compute FID only on the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(), time() - elapsed_time),
                    end=' ' if epoch != 0 else '\n')
                # Evaluation step => calculate FID (skipped on the first batch)
                if batches_done != 0:
                    # NOTE(review): 64 generated samples are compared with
                    # opt.fid_batch test images - confirm this is intended.
                    fid_z = torch.randn(64, opt.latent_dim, 1, 1)
                    if cuda:
                        fid_z = fid_z.cuda()
                    del gen_imgs
                    # Evaluation only: no autograd graph is needed here.
                    with torch.no_grad():
                        gen_imgs = generator(fid_z)
                        mu_gen, sigma_gen = calculate_activation_statistics(
                            gen_imgs, fic_model)
                        mu_test, sigma_test = calculate_activation_statistics(
                            test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()
def run(rank, size):
    """Run the fault-tolerant federated LSGAN training loop for one process.

    Same round-based scheme as the other LSGAN variants (group selection,
    model broadcast, local training, periodic averaging), plus checkpointing:
    if ``cp_path + "/checkpoint"`` exists, models, epoch, elapsed time and
    ``fl_round`` are restored from it before training.

    Args:
        rank: Rank of the current process (0 is the server).
        size: Total number of processes (server + workers).
    """
    global fl_round
    global rat_per_class

    # Least-squares GAN objective (MSE).
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    restart_count = 0  # only read by the optional crash simulation below
    epch = 0           # epoch to resume from
    el_time = 0        # training time accumulated before the restart
    ckpt_file = cp_path + "/checkpoint"
    if os.path.isfile(ckpt_file):
        print("Conotinuing traing from a checkpoint")
        # Models are still on the CPU here (moved to the GPU just below), so
        # load onto the CPU; this also lets a checkpoint written on a GPU
        # machine be restored on a CPU-only one.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(ckpt_file, map_location="cpu")
        generator.load_state_dict(checkpoint['gen'])
        discriminator.load_state_dict(checkpoint['disc'])
        epch = checkpoint['epoch']
        el_time = checkpoint['time']
        fl_round = checkpoint['fl_round']
        restart_count = restart_count + 1

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Configure data loader.
    # The manager is created unconditionally because rank 0 needs it for the
    # test set even when every device shares the same training data.
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size,
                             size - 1, size, rank, opt.iid, num_servers)
    # Set this flag to True if all devices are required to have the same
    # data (not realistic; only for simulation).
    same_data = False
    if same_data:
        os.makedirs("../data/mnist", exist_ok=True)
        train_set = torch.utils.data.DataLoader(
            datasets.MNIST(
                "../data/mnist",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(opt.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ]),
            ),
            batch_size=opt.batch_size,
            shuffle=True,
        )
    else:
        train_set, _ = manager.get_train_set(opt.magic_num)

    # Per-class sample counts of the local data; gathered at the server so
    # it can make informed sampling decisions.
    lbl_count = [0 for _ in range(10)]
    for _, lbls in train_set:
        for lbl in lbls:
            lbl_count[lbl.item()] += 1
    workers_classes = gather_lbl_count(lbl_count)
    if rank == 0:
        print(workers_classes)

    # Aggregate number of samples per class, entered manually here.
    num_per_class = [
        5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949
    ]
    all_samples = sum(num_per_class)
    rat_per_class = [float(n / all_samples) for n in num_per_class]

    # Initialize all groups for the whole training process
    init_groups(size, workers_classes)
    print("Rank {} Done initializing {} groups".format(rank, len(all_groups)))

    # Entropy-based weight of each worker (server side), derived from the
    # class frequencies and scaled by the worker's share of the samples.
    if rank == 0:
        entropies = [
            stats.entropy(np.array(freq_l) / sum(freq_l), rat_per_class) *
            (sum(freq_l) / all_samples) for freq_l in workers_classes
        ]

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    print("cuda is there? ", cuda)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # For FID calculations (server only)
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        # Keeps the LAST batch yielded by the test loader as the reference.
        for i, t in enumerate(test_set):
            test_imgs = t[0].cuda() if cuda else t[0]
        print("just before training....server is talking")
        sys.stdout.flush()

    # ----------
    #  Training
    # ----------
    elapsed_time = time.time()
    num_batches = 0
    # Global state flag used to sync between workers and the server:
    # True means a new round must start on the next batch.
    done_round = True
    group = None

    # The following hack makes the run last as many iterations as the user
    # is aiming for: with skewed data the local loader can be shorter than
    # the estimate, so the epoch count is scaled up accordingly.
    # Estimate assumes a 50,000-image dataset split over `size` processes.
    est_len = 50000 // (size * opt.batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))

    if rank == 0:
        print("Starting training...")
        sys.stdout.flush()

    # Resume from the epoch stored in the checkpoint (0 on a fresh start).
    epoch = epch
    while epoch < opt.n_epochs:
        for i, (imgs, _) in enumerate(train_set):
            if done_round:
                # A new round starts: pick this round's group and hand its
                # members the latest version of the models.
                fl_round += 1
                g = all_groups_np[fl_round % len(all_groups)]
                group = all_groups[fl_round % len(all_groups)]
                choose_r0 = False
                if rank == 0:
                    choose_r0 = choose_r[fl_round % len(all_groups)]
                if rank in g:
                    broadcast_model(generator, group, elapsed_time)
                    broadcast_model(discriminator, group, elapsed_time)
                    done_round = False
                else:
                    # Not chosen this round: no work, wait for the next
                    # announcement from the server.
                    done_round = True
                    # Advance the pointer for workers that sit this round out
                    num_batches = num_batches + opt.local_steps
                    continue

            # To simulate/test a server crash, make rank 0 fail here once
            # `time.time() - elapsed_time` exceeds a threshold and
            # `restart_count == 0`.

            num_batches += 1

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            # Generate a batch of images
            gen_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            d_gen = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()

            # Averaging step, added because of the distributed setup.
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    # Weighting scheme using the entropies based on the
                    # frequency of samples of each class at each worker.
                    cur_gp = all_groups_np[fl_round % len(all_groups)]
                    if rank == 0:
                        weights = [entropies[int(wrk)] for wrk in cur_gp]
                    else:
                        # Dummy uniform weights on the workers
                        weights = [1.0 / len(cur_gp) for _ in cur_gp]
                    # Experiments show that doing this is bad anyway!
                    average_models(generator, group, choose_r0, weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(generator, group, choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True

            # NOTE(review): nesting reconstructed from a whitespace-collapsed
            # source; the gradient zeroing is assumed to run on every batch,
            # not only on averaging steps - confirm against the original file.
            if rank == 0 and not choose_r0:
                # Server not selected to train: zero its gradients.
                for param in generator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
            d_loss.backward()

            # Averaging step, added because of the distributed setup.
            if num_batches % opt.local_steps == 0 and num_batches > 0:
                if opt.weight_avg:
                    # `weights` was computed in the generator averaging step
                    # of this same batch.
                    average_models(discriminator, group, choose_r0, weights,
                                   elapsed_time=elapsed_time)
                else:
                    average_models(discriminator, group, choose_r0,
                                   elapsed_time=elapsed_time)
                done_round = True
            if rank == 0 and not choose_r0:
                for param in discriminator.parameters():
                    param.grad.data = torch.zeros(
                        param.size()).cuda() if cuda else torch.zeros(
                            param.size())
            optimizer_D.step()

            # Print stats and compute FID only on the server
            batches_done = epoch * len(train_set) + i
            if rank == 0 and batches_done % opt.sample_interval == 0:
                print(
                    "Rank %d [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time %f"
                    % (rank, epoch, opt.n_epochs, i, len(train_set),
                       d_loss.item(), g_loss.item(),
                       time.time() - elapsed_time + el_time),
                    end=' ' if epoch != 0 else '\n')
                # Evaluation step => calculate FID (skipped on the first batch)
                if batches_done != 0:
                    fid_z = Variable(Tensor(np.random.normal(
                        0, 1, (opt.fid_batch, opt.latent_dim))))
                    del gen_imgs
                    # Evaluation only: no autograd graph is needed here.
                    with torch.no_grad():
                        gen_imgs = generator(fid_z)
                        mu_gen, sigma_gen = calculate_activation_statistics(
                            gen_imgs, fic_model)
                        mu_test, sigma_test = calculate_activation_statistics(
                            test_imgs[:opt.fid_batch], fic_model)
                    fid = calculate_frechet_distance(mu_gen, sigma_gen,
                                                     mu_test, sigma_test)
                    print("FL-round {} FID Score: {}".format(fl_round, fid))
                    sys.stdout.flush()

                    # For fault tolerance.
                    # NOTE(review): nesting reconstructed from a
                    # whitespace-collapsed source; the checkpoint is assumed
                    # to be written by the server after each evaluation -
                    # confirm against the original file.
                    print("saving checkpoint")
                    state = {
                        'disc': discriminator.state_dict(),
                        'gen': generator.state_dict(),
                        'epoch': epoch,
                        'time': time.time() - elapsed_time + el_time,
                        'fl_round': fl_round
                    }
                    # Write to a temporary file and rename it into place, so
                    # a crash during the save cannot corrupt the previous
                    # checkpoint.
                    tmp_file = ckpt_file + ".tmp"
                    torch.save(state, tmp_file)
                    os.replace(tmp_file, ckpt_file)
        epoch = epoch + 1
def run(rank, size):
    """Distributed synchronous SGD main function.

    Every rank runs the same loop: forward pass, then (on ranks
    ``>= num_ps`` only) loss and backward pass, then ``reduce_gradients``,
    a barrier over all ranks and an optimizer step. After each epoch, ranks
    ``< num_ps`` report test accuracy and the other ranks report their
    accumulated loss.

    Args:
        rank: Rank of the current process.
        size: Total size of the world (num_workers + num_servers).
    """
    # Preparing hyper-parameters
    torch.manual_seed(1234)
    manager = DatasetManager(dataset, minibatch, num_workers, size, rank)
    train_set, bsz = manager.get_train_set()
    test_set = manager.get_test_set()

    if torch.cuda.device_count() > 0:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu:0")
        print("CPU WARNING =====================================================================")
    print("Rank {} -> Device {}".format(rank, device))

    model = select_model(model_n, device)
    # For Cifar10: lr 0.001 and momentum 0.9; MNIST: 0.01 and 0.5
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                          weight_decay=wd)
    # A learning-rate schedule is only used for resnet50.
    scheduler = None
    if model_n == 'resnet50':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[25, 50], gamma=0.1)
    loss_fn = select_loss(loss_fn_n)

    # Group containing every rank, used for the per-iteration barrier.
    world = dist.new_group(list(range(size)))
    # If PS are Byzantine, some subgroups of the world are required; they
    # are initialized here.
    init_groups()
    print("-------------------------------- Rank {} have already done init groups...".format(rank))
    sys.stdout.flush()

    start_time = time.time()
    # Training loop
    print("One epoch has how many iterations: ", len(train_set))
    for epoch in range(epochs):
        epoch_loss = 0.0
        model.train()
        for index, (data, target) in enumerate(train_set):
            if log:
                print("Rank {} Starting iteration {}".format(rank, index))
            train_time = time.time()
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Only ranks >= num_ps compute a loss and local gradients.
            if rank >= num_ps:
                loss = loss_fn(output, target)
                loss.backward()
                epoch_loss += loss.item()
            if bench:
                print("Rank {} Train time {} ".format(rank, time.time() - train_time))
            if log:
                print("Rank {} Loop iteration {} Loss {}".format(rank, index, epoch_loss))
            sys.stdout.flush()

            reduce_time = time.time()
            reduce_gradients(model, rank, device, index)
            if bench:
                print("Rank {}, reduce time {} ".format(rank, time.time() - reduce_time))
            dist.barrier(world)
            optimizer.step()

        # Advance the schedule once per epoch, AFTER this epoch's optimizer
        # steps. Calling it before them (as was done previously) makes
        # PyTorch >= 1.1 skip the first value of the schedule, i.e. the
        # milestones take effect one epoch early.
        if scheduler is not None:
            scheduler.step()

        # Testing
        if rank < num_ps:
            test_time = time.time()
            acc = get_accuracy(model, test_set, device)
            print('Rank ', rank, ' epoch: ', epoch, ' acc: ', acc, "time: ",
                  time.time() - start_time)
            print("Rank {}, test time {} ".format(rank, time.time() - test_time))
        else:
            print('Rank ', rank, 'epoch: ', epoch, 'loss: ', epoch_loss,
                  "time: ", time.time() - start_time)
        sys.stdout.flush()
def run(rank, size):
    """Run the MD-GAN-style distributed GAN training loop on this process.

    Rank 0 is the server: each round it generates two batches (``X_g`` and
    ``X_d``), broadcasts them to every worker group, takes one generator step
    and, every 20th round, prints timing and an FID score. Every other rank is
    a worker: it receives the two batches, updates its discriminator on local
    data for a bounded number of loader iterations, then takes one generator
    step. All ranks call ``broadcast_model`` at the start of a round and
    ``average_models`` on the generator at the end of it.

    NOTE(review): this function was recovered from whitespace-flattened
    source, so the block nesting is reconstructed. Confirm in particular that
    the worker-side generator step belongs under ``rank != 0`` and that the
    FID evaluation belongs under the every-20th-round server guard.

    Args:
        rank: Rank of the current process (0 is the server).
        size: Total number of processes (``size - 1`` workers plus the server).
    """
    global fl_round
    global rat_per_class  # declared by the original; not used in this function

    # NOTE: minimizes MSE instead of BCE.
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    manager = DatasetManager(opt.model, opt.batch_size, opt.img_size, size - 1, size, rank, opt.iid)
    train_set, _ = manager.get_train_set(opt.max_samples)
    init_groups(size)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def sample_noise(count):
        """Draw ``count`` standard-normal latent vectors of size opt.latent_dim."""
        return Variable(Tensor(np.random.normal(0, 1, (count, opt.latent_dim))))

    # For FID calculations (server only)
    if rank == 0:
        fic_model = InceptionV3()
        if cuda:
            fic_model = fic_model.cuda()
        test_set = manager.get_test_set()
        # The loop keeps the images of the last batch the test loader yields.
        for test_batch in test_set:
            # Only move to the GPU when one is in use, like the rest of the code.
            test_imgs = test_batch[0].cuda() if cuda else test_batch[0]

    # ----------
    #  Training
    # ----------
    elapsed_time = time()

    # Hack to actually run the number of rounds the user is aiming for: with
    # skewed data the local loader can be shorter than the estimate, so the
    # epoch count is scaled up accordingly. The estimate assumes a dataset of
    # 50,000 images split across all processes.
    est_len = 50000 // (size * opt.batch_size)
    act_len = len(train_set)
    if act_len < est_len:
        opt.n_epochs = int(opt.n_epochs * (est_len / act_len))

    # Peek at the first training batch; its leading dimension sizes the label
    # and noise tensors below. `i` (always 0 here) is reused by the stats print.
    imgs = []
    for i, (first_batch, _) in enumerate(train_set):
        imgs = first_batch
        break

    for epoch in range(opt.n_epochs):
        broadcast_model(generator, elapsed_time=elapsed_time)
        fl_round += 1

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # -----------------
        #  Train Generator
        # -----------------
        # Every rank runs one forward pass: the server trains on `temp`, the
        # workers only use its size to allocate their receive buffers.
        temp = generator(sample_noise(imgs.shape[0]))
        if rank == 0:
            # MD-GAN trains the generator only on the server.
            optimizer_G.zero_grad()
            # Generate the two batches that are shipped to the workers.
            X_g = generator(sample_noise(imgs.shape[0]))
            X_d = generator(sample_noise(imgs.shape[0]))
            for n in range(size - 1):
                dist.broadcast(tensor=X_g, src=0, group=all_groups[n])
                dist.broadcast(tensor=X_d, src=0, group=all_groups[n])
        else:
            # First, workers receive the batches generated by the server.
            X_g = torch.zeros(temp.size())
            X_d = torch.zeros(temp.size())
            dist.broadcast(tensor=X_g, src=0, group=all_groups[rank - 1])
            dist.broadcast(tensor=X_d, src=0, group=all_groups[rank - 1])
            # Previously moved to the GPU unconditionally, which fails on
            # CPU-only runs.
            if cuda:
                X_g = X_g.cuda()
                X_d = X_d.cuda()

        # Loss measures generator's ability to fool the discriminator
        if rank == 0:
            d_gen = discriminator(temp)
            g_loss = adversarial_loss(d_gen, valid)
            g_loss.backward()
            optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        if rank != 0:
            L = 12  # MD-GAN parameter: a worker should only do L iterations.
            for step, (imgs_t, _) in enumerate(train_set):
                real_imgs = Variable(imgs_t.type(Tensor))
                # Skip batches whose size differs, to avoid mismatch problems.
                if real_imgs.size()[0] != opt.batch_size:
                    continue
                optimizer_D.zero_grad()
                # Measure discriminator's ability to classify real from
                # generated samples.
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss = adversarial_loss(discriminator(X_d.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)
                d_loss.backward()
                optimizer_D.step()
                if step == L - 1:
                    break

            # NOTE(review): placement under `rank != 0` is reconstructed.
            optimizer_G.zero_grad()
            X_g = generator(sample_noise(imgs.shape[0]))
            g_loss = adversarial_loss(discriminator(X_g), valid)
            g_loss.backward()
            optimizer_G.step()

        average_models(generator, elapsed_time=elapsed_time)
        del X_g
        del X_d

        # Print stats and compute FID only on the server, every 20th round.
        if rank == 0 and fl_round % 20 == 0:
            print(
                "Rank %d [Epoch %d/%d] [Batch %d/%d] time %f"
                % (rank, epoch, opt.n_epochs, i, len(train_set), time() - elapsed_time),
                end=' ' if epoch != 0 else '\n'
            )
            # Evaluation step: calculate FID of generated vs. test images.
            gen_imgs = generator(sample_noise(opt.fid_batch))
            mu_gen, sigma_gen = calculate_activation_statistics(gen_imgs, fic_model)
            mu_test, sigma_test = calculate_activation_statistics(test_imgs[:opt.fid_batch], fic_model)
            fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test, sigma_test)
            print("FL-round {} FID Score: {}".format(fl_round, fid))
            sys.stdout.flush()