def resume(self, checkpoint_dir, param):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_i.load_state_dict(state_dict['i'])
    self.gen_r.load_state_dict(state_dict['r'])
    self.gen_s.load_state_dict(state_dict['s'])
    if self.with_mapping:
        self.fea_m.load_state_dict(state_dict['fm'])
        self.fea_s.load_state_dict(state_dict['fs'])
    self.best_result = state_dict['best_result']
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_r.load_state_dict(state_dict['r'])
    self.dis_s.load_state_dict(state_dict['s'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
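# Every variant below relies on a `get_model_list` helper that is not shown here.
# A minimal sketch, assuming the MUNIT-style convention the call sites imply
# (return the lexicographically last checkpoint in `dirname` whose filename
# contains `key` and ends in ".pt"); the actual implementation may differ per repo.
import os

def get_model_list(dirname, key):
    if not os.path.exists(dirname):
        return None
    models = [os.path.join(dirname, f) for f in os.listdir(dirname)
              if os.path.isfile(os.path.join(dirname, f)) and key in f and f.endswith(".pt")]
    if not models:
        return None
    models.sort()
    return models[-1]  # filenames embed the iteration count, so the last one sorts highest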
def resume(self, checkpoint_dir, configs):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
    self.gen.load_state_dict(state_dict['a'])
    iterations = int(last_model_name[-15:-7]) if 'avg' in last_model_name \
        else int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
    self.dis.load_state_dict(state_dict['b'])
    # Load optimizers
    # state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'), map_location=lambda storage, loc: storage)
    # self.dis_opt.load_state_dict(state_dict['dis'])
    # self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, configs, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, configs, iterations)
    if torch.__version__ != '0.4.1':
        # On PyTorch versions other than 0.4.1, fast-forward the schedulers
        # to the resumed iteration manually.
        for _ in range(iterations):
            self.gen_scheduler.step()
            self.dis_scheduler.step()
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_a.load_state_dict(state_dict['a'])
    self.gen_b.load_state_dict(state_dict['b'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict['a'])
    self.dis_b.load_state_dict(state_dict['b'])
    # Load content classifier
    last_model_name = get_model_list(checkpoint_dir, "con_cla")
    state_dict = torch.load(last_model_name)
    self.content_classifier.load_state_dict(state_dict['con_cla'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    self.cla_opt.load_state_dict(state_dict['con_cla'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume_prev(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    print('resume from generator checkpoint named:', last_model_name)
    state_dict = torch.load(last_model_name)
    self.gen_a.load_state_dict(state_dict['a'])
    self.gen_b.load_state_dict(state_dict['b'])
    try:
        iterations = int(float(last_model_name[4:11]))
    except ValueError:
        # The name includes a directory prefix; parse the basename instead.
        iterations = int(float(last_model_name.split('/')[-1][4:11]))
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    print('resume from discriminator checkpoint named:', last_model_name)
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict['a'])
    self.dis_b.load_state_dict(state_dict['b'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.models.gen.load_state_dict(state_dict['gen'])
    self.models.gen_test.load_state_dict(state_dict['gen_test'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.models.dis.load_state_dict(state_dict['dis'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Move restored optimizer state tensors back onto the GPU.
    for state in self.dis_opt.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
    for state in self.gen_opt.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators (a single shared generator serves both domains)
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_a.load_state_dict(state_dict['a'])
    self.gen_b = self.gen_a
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict['a'])
    self.dis_b = self.dis_a
    # Load ID discriminator
    last_model_name = get_model_list(checkpoint_dir, "id")
    state_dict = torch.load(last_model_name)
    self.id_a.load_state_dict(state_dict['a'])
    self.id_b = self.id_a
    # Load optimizers; fall back to fresh optimizer state if the file is missing.
    try:
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.id_opt.load_state_dict(state_dict['id'])
    except (IOError, KeyError):
        pass
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    """
    Resume training by loading the network parameters.

    Arguments:
        checkpoint_dir {string} -- path to the directory where the checkpoints are saved
        hyperparameters {dictionary} -- dictionary with all hyperparameters

    Returns:
        int -- number of iterations (used by the optimizer)
    """
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    if self.gen_state == 0:
        self.gen_a.load_state_dict(state_dict["a"])
        self.gen_b.load_state_dict(state_dict["b"])
    elif self.gen_state == 1:
        self.gen.load_state_dict(state_dict["2"])
    else:
        print("self.gen_state unknown value:", self.gen_state)
    # Load domain classifier
    if self.domain_classif == 1:
        last_model_name = get_model_list(checkpoint_dir, "domain_classif")
        state_dict = torch.load(last_model_name)
        self.domain_classifier.load_state_dict(state_dict["d"])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict["a"])
    self.dis_b.load_state_dict(state_dict["b"])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
    self.dis_opt.load_state_dict(state_dict["dis"])
    self.gen_opt.load_state_dict(state_dict["gen"])
    if self.domain_classif == 1:
        self.dann_opt.load_state_dict(state_dict["dann"])
        self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters, iterations)
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print("Resume from iteration %d" % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters, resume_epoch):
    print("--> " + checkpoint_dir)
    # Load generator.
    last_model_name = get_model_list(checkpoint_dir, "gen", resume_epoch)
    state_dict = torch.load(last_model_name)
    self.gen.load_state_dict(state_dict)
    epochs = int(last_model_name[-11:-3])
    # Load supervised model.
    last_model_name = get_model_list(checkpoint_dir, "sup", resume_epoch)
    state_dict = torch.load(last_model_name)
    self.sup.load_state_dict(state_dict)
    return epochs
def resume(self, checkpoint_dir, hyperparameters):
    print("--> " + checkpoint_dir)
    # Load generator.
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen.load_state_dict(state_dict)
    epochs = int(last_model_name[-11:-3])
    # Load supervised model.
    last_model_name = get_model_list(checkpoint_dir, "sup")
    state_dict = torch.load(last_model_name)
    self.sup.load_state_dict(state_dict)
    # Load discriminator.
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis.load_state_dict(state_dict)
    # Load optimizers.
    last_model_name = get_model_list(checkpoint_dir, "opt")
    state_dict = torch.load(last_model_name)
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Move restored optimizer state tensors back onto the GPU.
    for state in self.dis_opt.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()
    for state in self.gen_opt.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()
    # Reinitialize schedulers.
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, epochs)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, epochs)
    print('Resume from epoch %d' % epochs)
    return epochs
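# The two loops above (and the identical ones in the earlier gen/gen_test variant)
# repeat one pattern: after load_state_dict, optimizer state tensors can sit on the
# wrong device, so they are moved explicitly to the GPU. A minimal generic helper
# capturing that pattern; `optimizer_to` is a hypothetical name, not something
# these repos define.
import torch

def optimizer_to(optimizer, device):
    # Walk every per-parameter state dict and move tensor entries to `device`.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)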
def load(self, checkpoint_dir):
    last_model_name = get_model_list(checkpoint_dir, "enc")
    last_epoch = int(last_model_name.split('_')[1].split('.')[0])
    encoder = torch.load(last_model_name)
    self.model.load_state_dict(encoder["model"])
    self.opt.load_state_dict(encoder["opt"])
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        self.opt, T_max=self.T_max, eta_min=0, last_epoch=last_epoch)
    print(f"Resume encoder from epoch {last_epoch + 1}")
    return last_epoch + 1
def load_checkpoint(self, model_path=None):
    if not model_path:
        model_dir = self.model_dir
        model_path = get_model_list(model_dir, "gen")  # last model
    state_dict = torch.load(model_path, map_location=self.device)
    self.gen.load_state_dict(state_dict['gen'])
    epochs = int(model_path[-7:-3])
    print('Load from epoch %d' % epochs)
    return epochs
def resume(self, checkpoint_dir, hyperparameters, particular=False, checkpoint=''):
    # Load generators: when `particular` is False, resume from the latest
    # checkpoint; otherwise load the explicitly named one.
    if not particular:
        last_model_name = get_model_list(checkpoint_dir, "vae")
    else:
        if checkpoint == '':
            sys.exit('Specified checkpoint path is empty')
        last_model_name = os.path.join(checkpoint_dir, checkpoint)
    state_dict = torch.load(last_model_name)
    self.vae.load_state_dict(state_dict)
def resume(self, checkpoint_dir, hps):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
    self.gen_a.load_state_dict(state_dict['a'])
    self.gen_b.load_state_dict(state_dict['b'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name, map_location=lambda storage, loc: storage)
    self.dis_a.load_state_dict(state_dict['a'])
    self.dis_b.load_state_dict(state_dict['b'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'),
                            map_location=lambda storage, loc: storage)
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hps, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hps, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_a.load_state_dict(state_dict['a'])
    self.gen_b.load_state_dict(state_dict['b'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict['a'])
    self.dis_b.load_state_dict(state_dict['b'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
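# The `get_scheduler` helper used to rebuild learning-rate schedules is also not
# shown. A minimal sketch, assuming the MUNIT-style step policy the call sites
# suggest (hyperparameters supply 'lr_policy', 'step_size', 'gamma'); the real
# configuration keys may differ per repo.
from torch.optim import lr_scheduler

def get_scheduler(optimizer, hyperparameters, iterations=-1):
    if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
        scheduler = None  # constant learning rate: nothing to schedule
    elif hyperparameters['lr_policy'] == 'step':
        # last_epoch=iterations fast-forwards the schedule to the resumed point;
        # note that recent PyTorch requires 'initial_lr' in each param group
        # when last_epoch >= 0 is passed.
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'],
                                        last_epoch=iterations)
    else:
        raise NotImplementedError(
            'lr policy [%s] is not implemented' % hyperparameters['lr_policy'])
    return scheduler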
def resume(self, checkpoint_dir, param):
    # Load the student (generator)
    last_model_name = get_model_list(checkpoint_dir, "student")
    state_dict = torch.load(last_model_name)
    self.generator.load_state_dict(state_dict['student'])
    self.best_result = state_dict['best_result']
    epoch = int(last_model_name[-6:-3])
    # Load the teacher (discriminator)
    last_model_name = get_model_list(checkpoint_dir, "teacher")
    state_dict = torch.load(last_model_name)
    self.discriminator.load_state_dict(state_dict['teacher'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['teacher'])
    self.stu_opt.load_state_dict(state_dict['student'])
    # Re-initialize schedulers
    try:
        self.stu_scheduler = get_scheduler(self.stu_opt, param, epoch)
    except Exception as e:
        print('Warning: {}'.format(e))
    print('Resume from epoch %d' % epoch)
    return epoch
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_AB.load_state_dict(state_dict['AB'])
    self.gen_BA.load_state_dict(state_dict['BA'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_A.load_state_dict(state_dict['A'])
    self.dis_B.load_state_dict(state_dict['B'])
    self.dis_2.load_state_dict(state_dict['2'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hp):
    last_model_name = get_model_list(checkpoint_dir, "gen")
    if last_model_name:
        state_dict = torch.load(last_model_name)
        self.model.gen.load_state_dict(state_dict['gen'])
        self.model.gen_test.load_state_dict(state_dict['gen_test'])
        iterations = int(last_model_name[-11:-3])
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.model.dis.load_state_dict(state_dict['dis'])
        last_opt_name = get_model_list(checkpoint_dir, "optimizer")
        state_dict = torch.load(last_opt_name)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hp, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp, iterations)
    else:
        iterations = 0
    print(f'Resume GAN from iteration {iterations}')
    return iterations
def resume(self, checkpoint_dir, options):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.encoder.load_state_dict(state_dict['a'])
    self.decoder.load_state_dict(state_dict['b'])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.discriminator.load_state_dict(state_dict['a'])
    # Load optimizers
    last_model_name = get_model_list(checkpoint_dir, "opt")
    state_dict = torch.load(last_model_name)
    self.gen_opt.load_state_dict(state_dict['a'])
    self.dis_opt.load_state_dict(state_dict['b'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, options, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, options, iterations)
    print('Resume from iteration %d' % iterations)
    # Free the loaded checkpoint tensors before training continues.
    del state_dict, last_model_name
    torch.cuda.empty_cache()
    return iterations
def resume(self, checkpoint_dir, param):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.model.load_state_dict(state_dict['i'])
    iterations = int(last_model_name[-11:-3])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.gen_a.load_state_dict(state_dict["a"])
    self.gen_b.load_state_dict(state_dict["b"])
    iterations = int(last_model_name[-11:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.dis_a.load_state_dict(state_dict["a"])
    self.dis_b.load_state_dict(state_dict["b"])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
    self.dis_opt.load_state_dict(state_dict["dis"])
    self.gen_opt.load_state_dict(state_dict["gen"])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print("Resume from iteration %d" % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None):
    # Load generators
    if gen_model is None:
        gen_model = get_model_list(checkpoint_dir, "gen")  # last gen model
    gen_state_dict = torch.load(gen_model)
    self.gen_a.load_state_dict(gen_state_dict['a'])
    self.gen_b.load_state_dict(gen_state_dict['b'])
    iterations = int(gen_model[-11:-3])
    # Load discriminators
    if dis_model is None:
        dis_model = get_model_list(checkpoint_dir, "dis")
    dis_state_dict = torch.load(dis_model)
    self.dis_a.load_state_dict(dis_state_dict['a'])
    self.dis_b.load_state_dict(dis_state_dict['b'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, param):
    # Load generators
    last_model_name = get_model_list(checkpoint_dir, "gen")
    state_dict = torch.load(last_model_name)
    self.generator.load_state_dict(state_dict['generator'])
    self.best_result = state_dict['best_result']
    epoch = int(last_model_name[-6:-3])
    # Load discriminators
    last_model_name = get_model_list(checkpoint_dir, "dis")
    state_dict = torch.load(last_model_name)
    self.discriminator_bg.load_state_dict(state_dict['bg'])
    self.discriminator_rf.load_state_dict(state_dict['rf'])
    # Load optimizers
    state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    self.dis_opt.load_state_dict(state_dict['dis'])
    self.gen_opt.load_state_dict(state_dict['gen'])
    # Reinitialize schedulers
    try:
        self.dis_scheduler = get_scheduler(self.dis_opt, param, epoch)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, epoch)
    except Exception as e:
        print('Warning: {}'.format(e))
    print('Resume from epoch %d' % epoch)
    return epoch
def resume(self, checkpoint_dir, params, version=None):
    model_name = get_model_list(checkpoint_dir, 'model', version)
    state_dict = torch.load(model_name)
    self.model.load_state_dict(state_dict)
    iterations = int(model_name[-12:-4])
    if version is not None:
        assert version == iterations
    try:
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pth'))
        self.encoder_opt.load_state_dict(state_dict)
    except IOError:
        print('use a new optimizer')
    # Reinitialize scheduler
    self.encoder_scheduler = get_scheduler(self.encoder_opt, params, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def resume(self, checkpoint_dir, hyperparameters, need_opt=True, path=None):
    # Load generators
    if path is None:
        last_model_name = get_model_list(checkpoint_dir, "gen")
    else:
        last_model_name = path
    state_dict = torch.load(last_model_name)
    self.gen.load_state_dict(state_dict['a'])
    iterations = int(last_model_name[-11:-3])
    if need_opt:
        state_dict = torch.load(
            os.path.join(checkpoint_dir, 'optimizer_' + last_model_name[-11:-3] + '.pt'))
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    print('Resume from iteration %d' % iterations)
    return iterations
def collaborative_private_model_mnist_train(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the models to initialize
    train_dataset, test_dataset = get_public_dataset(args)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model_removelogsoftmax,  # dict of model constructors
        "3_layer_CNN": CNN_3layer_fc_model_removelogsoftmax
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    if args.new_collaborative_training:
        model_list, model_type_list = get_model_list(args.privateurl, modelsindex, models)
    else:
        model_list, model_type_list = get_model_list(args.Collaborativeurl, modelsindex, models)
    epoch_groups = MNIST_random(train_dataset, args.collaborative_epoch)
    train_loss = []
    test_accuracy = []
    for i in range(args.user_number):
        train_loss.append([])
    for i in range(args.user_number):
        test_accuracy.append([])
    for epoch in range(args.collaborative_epoch):
        train_batch_loss = []
        for i in range(args.user_number):
            train_batch_loss.append([])
        trainloader = DataLoader(DatasetSplit(train_dataset, list(epoch_groups[epoch])),
                                 batch_size=256, shuffle=True)
        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            # Initialize per-sample accumulators for the ensemble outputs
            temp_sum_result = [[] for _ in range(len(labels))]
            for item in range(len(temp_sum_result)):
                for i in range(args.output_classes):
                    temp_sum_result[item].append(0)
            # Sum the outputs of every participant's model
            for n, model in enumerate(model_list):
                with torch.no_grad():
                    model.to(device)
                    model.eval()
                    outputs = model(images)
                    pred_labels = outputs.tolist()
                    temp_sum_result = list_add(pred_labels, temp_sum_result)
            # Average the summed outputs over the participating users
            temp_sum_result = get_avg_result(temp_sum_result, args.user_number)
            labels = torch.tensor(temp_sum_result)
            labels = labels.to(device)
            # Train each local model toward the consensus (averaged) outputs
            for n, model in enumerate(model_list):
                model.to(device)
                model.train()
                if args.optimizer == 'sgd':
                    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
                elif args.optimizer == 'adam':
                    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                                 weight_decay=1e-4)
                criterion = nn.L1Loss(reduction='mean').to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('Collaborative training : Local Model {} Type {} Train Epoch: {} '
                          '[{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                          .format(n, model_type_list[n], epoch + 1,
                                  batch_idx * len(images), len(trainloader.dataset),
                                  100. * batch_idx / len(trainloader), loss.item()))
                train_batch_loss[n].append(loss.item())
                torch.save(model.state_dict(),
                           'Src/CollaborativeModel/LocalModel{}Type{}.pkl'.format(
                               n, model_type_list[n]))
        for index in range(len(train_batch_loss)):
            loss_avg = sum(train_batch_loss[index]) / len(train_batch_loss[index])
            train_loss[index].append(loss_avg)
    plt.figure()
    for index in range(len(train_loss)):
        plt.plot(range(len(train_loss[index])), train_loss[index])
    plt.title('collaborative_train_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('Src/Figure/collaborative_train_losses.png')
    plt.show()
    print('End Public Training')
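# `DatasetSplit` is used throughout the federated snippets but never defined here.
# A minimal sketch of the usual helper from federated-learning codebases (assumed,
# based on how it is called): it wraps a dataset and exposes only the samples
# whose indices belong to one user.
from torch.utils.data import Dataset

class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)  # sample indices assigned to this user

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label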
def private_dataset_train(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the models to initialize
    # Fetch the FEMNIST dataset
    train_dataset, test_dataset = get_private_dataset_balanced(args)
    user_groups = FEMNIST_iid(train_dataset, args.user_number)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # dict of model constructors
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    if args.new_private_training:
        model_list, model_type_list = get_model_list(args.initialurl, modelsindex, models)
        # model_list, model_type_list = get_model_list('Src/EmptyModel', modelsindex, models)
    else:
        # model_list, model_type_list = get_model_list(args.privateurl, modelsindex, models)
        model_list, model_type_list = get_model_list('Src/ModelNonIdFemnist',
                                                     modelsindex, models)
    private_model_private_dataset_train_losses = []
    private_model_private_dataset_validation_losses = []
    for n, model in enumerate(model_list):
        print('train Local Model {} on Private Dataset'.format(n))
        model.to(device)
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        trainloader = DataLoader(DatasetSplit(train_dataset, list(user_groups[n])),
                                 batch_size=32, shuffle=True)
        testloader = DataLoader(test_dataset, batch_size=128, shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        validation_epoch_losses = []
        print('Begin Private Training')
        earlyStopping = EarlyStopping(
            patience=5, verbose=True,
            path='Src/ModelNonIdFemnist/LocalModel{}Type{}.pkl'.format(n, model_type_list[n]))
        for epoch in range(args.privateepoch):
            model.train()
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 5 == 0:
                    print('Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                          .format(n, model_type_list[n], epoch + 1,
                                  batch_idx * len(images), len(trainloader.dataset),
                                  100. * batch_idx / len(trainloader), loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)
            model.eval()
            val_batch_losses = []
            for batch_idx, (images, labels) in enumerate(testloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                if batch_idx % 5 == 0:
                    print('Local Model {} Type {} Val Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                          .format(n, model_type_list[n], epoch + 1,
                                  batch_idx * len(images), len(testloader.dataset),
                                  100. * batch_idx / len(testloader), loss.item()))
                val_batch_losses.append(loss.item())
            loss_avg = sum(val_batch_losses) / len(val_batch_losses)
            validation_epoch_losses.append(loss_avg)
            earlyStopping(loss_avg, model)
            if earlyStopping.early_stop:
                print("Early stopping")
                break
        private_model_private_dataset_train_losses.append(train_epoch_losses)
        private_model_private_dataset_validation_losses.append(validation_epoch_losses)
    plt.figure()
    for i, val in enumerate(private_model_private_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val, label='model :' + str(i))
    plt.legend(loc='best')
    plt.title('private_model_private_non_iid_dataset_train_demo_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    x_major_locator = MultipleLocator(1)  # set the x-axis tick interval to 1
    ax = plt.gca()  # get the current axes
    ax.xaxis.set_major_locator(x_major_locator)  # major x ticks at multiples of 1
    plt.xlim(0, args.privateepoch)
    plt.savefig('Src/Figure/private_model_private_non_iid_dataset_train_demo_losses.png')
    plt.show()
    plt.figure()
    for i, val in enumerate(private_model_private_dataset_validation_losses):
        print(val)
        plt.plot(range(len(val)), val, label='model :' + str(i))
    plt.legend(loc='best')
    plt.title('private_model_private_non_iid_dataset_validation_demo_losses')
    plt.xlabel('epochs')
    plt.ylabel('Validation loss')
    x_major_locator = MultipleLocator(1)  # set the x-axis tick interval to 1
    ax = plt.gca()
    ax.xaxis.set_major_locator(x_major_locator)
    plt.xlim(0, args.privateepoch)
    plt.savefig('Src/Figure/private_model_private_non_iid_dataset_validation_demo_losses.png')
    plt.show()
    print('End Private Training')
data_loader = get_data_loader_folder(opts.input_folder, 1, False,
                                     new_size=config['crop_image_height'], crop=False)
# config['vgg_model_path'] = opts.output_path
if opts.trainer == 'MUNIT':
    style_dim = config['gen']['style_dim']
    trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
    trainer = UNIT_Trainer(config)
else:
    sys.exit("Only support MUNIT|UNIT")

last_model_name = get_model_list(opts.checkpoint, "gen")
state_dict = torch.load(last_model_name)
trainer.gen_a.load_state_dict(state_dict['a'])
trainer.gen_b.load_state_dict(state_dict['b'])

last_model_name = get_model_list(opts.checkpoint, "submodel")
state_dict = torch.load(last_model_name)
trainer.a2b.load_state_dict(state_dict['a2b'])
trainer.b2a.load_state_dict(state_dict['b2a'])

trainer.cuda()
trainer.eval()
encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode  # encode function
decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode  # decode function
get_style = trainer.b2a
output_directory = os.path.join(opts.output_path + "/outputs", model_name)
checkpoint_directory = os.path.join(output_directory, 'checkpoints')
os.makedirs(checkpoint_directory, exist_ok=True)
shutil.copy(opts.config,
            os.path.join(output_directory, 'config.yaml'))  # copy config file to output folder

dataloader = DataLoader(
    WCDataset(config['dataset_path']),
    batch_size=config['batch_size'],
    shuffle=True,
    drop_last=True,
    num_workers=config['num_workers'],
)

if opts.resume:  # opts.resume=False
    last_model_name = get_model_list(checkpoint_directory)
    iteration = int(last_model_name[:-4])
    net = get_net(last_model_name, dataloader.dataset.class_num)
    optimizer = get_optimizer(net, config)
    scheducer = get_scheducer(optimizer, config, iteration)
    print('Resume from iteration %d' % iteration)
else:
    iteration = 0
    net = get_net(config['weight_path'], dataloader.dataset.class_num)
    optimizer = get_optimizer(net, config)
    scheducer = get_scheducer(optimizer, config)

criterion = AngleLoss()
# criterion = nn.CrossEntropyLoss()
train(net, dataloader, criterion, optimizer, scheducer, config, train_writer,
      checkpoint_directory, iteration)
def collaborative_private_model_femnist_train(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the models to initialize
    # Fetch the FEMNIST dataset
    train_dataset, test_dataset = get_private_dataset_balanced(args)
    user_groups = FEMNIST_iid(train_dataset, args.user_number)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # dict of model constructors
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    model_list, model_type_list = get_model_list(args.Collaborativeurl, modelsindex, models)
    print('Begin Private Training')
    private_model_private_dataset_train_losses = []
    for n, model in enumerate(model_list):
        print('train Local Model {} on Private Dataset'.format(n))
        model.to(device)
        model.train()
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        trainloader = DataLoader(DatasetSplit(train_dataset, list(user_groups[n])),
                                 batch_size=5, shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        for epoch in range(args.Communication_private_epoch):
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 5 == 0:
                    print('Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                          .format(n, model_type_list[n], epoch + 1,
                                  batch_idx * len(images), len(trainloader.dataset),
                                  100. * batch_idx / len(trainloader), loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)
        torch.save(model.state_dict(),
                   'Src/CollaborativeModel/LocalModel{}Type{}.pkl'.format(
                       n, model_type_list[n]))
        private_model_private_dataset_train_losses.append(train_epoch_losses)
    plt.figure()
    for i, val in enumerate(private_model_private_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val)
    plt.title('collaborative_private_model_private_dataset_train_losses')
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('Src/Figure/collaborative_private_model_private_dataset_train_losses.png')
    plt.show()
    print('End Private Training')
                    required=True)
parser.add_argument('-s', '--save',
                    help='Path to save models and stats',
                    default=None,
                    required=True)
args = parser.parse_args()

data_path = args.data
save_path = args.save
train_mode = args.mode
num_epochs = args.epoch
batch_size = args.batch
start_time = str(int(time.time()))

list_of_models = get_model_list()
models_to_test = ['alexnet']
use_gpu = torch.cuda.is_available()
plot_colors = get_palet(len(models_to_test))
accfinal = 0
stats_file = (save_path + '/' + start_time + '_' +
              os.path.basename(data_path) + '_' + args.mode + '_stats.csv')
number_classes = len(glob.glob(data_path + '/train/*'))
stats = []
print('Data:', data_path)
print('Mode:', train_mode)
def continue_train_models(args):
    device = 'cuda' if args.gpu else 'cpu'
    # Set up the models to initialize
    train_dataset, test_dataset = get_public_dataset(args)
    models = {
        "2_layer_CNN": CNN_2layer_fc_model,  # dict of model constructors
        "3_layer_CNN": CNN_3layer_fc_model
    }
    modelsindex = ["2_layer_CNN", "3_layer_CNN"]
    model_list, model_type_list = get_model_list(args.initialurl, modelsindex, models)
    private_model_public_dataset_train_losses = []
    for n, model in enumerate(model_list):
        print('Continue train Local Model {}'.format(n))
        model.to(device)
        model.train()
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        trainloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
        criterion = nn.NLLLoss().to(device)
        train_epoch_losses = []
        print('Begin Public Training')
        for epoch in range(args.continue_epoch):
            train_batch_losses = []
            for batch_idx, (images, labels) in enumerate(trainloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                if batch_idx % 50 == 0:
                    print('Local Model {} Type {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                          .format(n, model_type_list[n], epoch + 1,
                                  batch_idx * len(images), len(trainloader.dataset),
                                  100. * batch_idx / len(trainloader), loss.item()))
                train_batch_losses.append(loss.item())
            loss_avg = sum(train_batch_losses) / len(train_batch_losses)
            train_epoch_losses.append(loss_avg)
        torch.save(model.state_dict(),
                   'Src/Model/LocalModel{}Type{}Epoch{}.pkl'.format(
                       n, model_type_list[n], args.epoch))
        private_model_public_dataset_train_losses.append(train_epoch_losses)
    plt.figure()
    for i, val in enumerate(private_model_public_dataset_train_losses):
        print(val)
        plt.plot(range(len(val)), val)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('Src/Figure/private_model_public_dataset_train_continue_losses.png')
    plt.show()
    print('End Public Training')
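# The federated snippets call a different, three-argument
# get_model_list(url, modelsindex, models) that returns instantiated networks
# together with their architecture names. Its implementation is not shown; a
# hedged sketch consistent with the call sites and the 'LocalModel{n}Type{type}.pkl'
# naming used when saving. The filename parsing and the no-argument constructors
# below are assumptions, not the repo's actual code.
import os
import torch

def get_model_list(url, modelsindex, models):
    model_list, model_type_list = [], []
    for filename in sorted(os.listdir(url)):
        if not filename.endswith('.pkl'):
            continue
        # Assumed: recover the architecture name from the 'Type...' part of the filename.
        model_type = next((t for t in modelsindex if t in filename), modelsindex[0])
        model = models[model_type]()  # assumes constructors take no required args
        model.load_state_dict(torch.load(os.path.join(url, filename)))
        model_list.append(model)
        model_type_list.append(model_type)
    return model_list, model_type_list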