def test(model_path, submit_csv=hparams.submit_file, submit_file=hparams.submit_file, best_thresh=None):
    test_dataset = AudioData(data_csv=submit_csv, data_file=submit_file, ds_type='submit',
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                             ]))
    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                             shuffle=False, num_workers=2)

    discriminator = Discriminator().to(hparams.gpu_device)
    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)
    checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    discriminator = discriminator.eval()
    # print('Model loaded')

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits_list = []
        labels_list = []
        img_names_list = []
        # for _ in range(hparams.repeat_infer):
        for (inp, labels, img_names) in tqdm(test_loader):
            inp = Variable(inp.float(), requires_grad=False)
            labels = Variable(labels.long(), requires_grad=False)
            inp = inp.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)
            inp = inp.view(-1, 1, 640, 64)
            inp = torch.cat([inp] * 3, dim=1)  # replicate mono channel to 3 channels
            pred_logits = discriminator(inp)
            pred_logits_list.append(pred_logits)
            labels_list.append(labels)
            img_names_list.append(img_names)
        pred_logits = torch.cat(pred_logits_list, dim=0)
        labels = torch.cat(labels_list, dim=0)
        pred_labels = pred_logits.max(1)[1]
        # NOTE: the original source is truncated here at `with open`; the
        # lines below are an assumed minimal completion that writes one
        # "name,prediction" row per example (output path is an assumption).
        all_names = [n for batch_names in img_names_list for n in batch_names]
        with open('submission.csv', 'w') as f:
            for name, pred in zip(all_names, pred_labels.tolist()):
                f.write('{},{}\n'.format(name, pred))
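# Hedged usage sketch: how `test` above would typically be invoked. The
# checkpoint path follows the `hparams.model + '.best'` naming used by the
# training script later in this collection; treat the path as an assumption.
if __name__ == '__main__':
    test(model_path=hparams.model + '.best')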
class face_learner(object):
    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
            self.growup = GrowUP().to(conf.device)
            self.discriminator = Discriminator().to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)

            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # Will not use anymore
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)

            print(self.class_num)
            print(conf)
            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.discriminator:
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                self.optimizer2 = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn, 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            print('optimizers generated')

            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 2
            self.save_every = len(self.loader)

            dataset_root = "/home/nas1_userD/yonggyu/Face_dataset/face_emore"
            self.lfw = np.load(
                os.path.join(dataset_root, "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root, "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # Will not use anymore
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong=0, positive_wrong=0):
        # defaults added for the two wrong-pair counts: several callers below
        # do not compute them and passed only four arguments in the original
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_scalar('{}_negative_wrong'.format(db_name), negative_wrong, self.step)
        self.writer.add_scalar('{}_positive_wrong'.format(db_name), positive_wrong, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()

        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings1 = np.zeros([len(carray) // 2, conf.embedding_size])
        embeddings2 = np.zeros([len(carray) // 2, conf.embedding_size])
        carray1 = carray[::2, ]
        carray2 = carray[1::2, ]
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray1):
                batch = torch.tensor(carray1[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:idx + conf.batch_size] = self.growup(
                        self.model(batch.to(conf.device))).cpu()
                idx += conf.batch_size
            if idx < len(carray1):
                batch = torch.tensor(carray1[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:] = self.growup(
                        self.model(batch.to(conf.device))).cpu()

            # bug fix: the original reused the advanced idx here, which
            # skipped essentially all of carray2
            idx = 0
            while idx + conf.batch_size <= len(carray2):
                batch = torch.tensor(carray2[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray2):
                batch = torch.tensor(carray2[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:] = self.model(batch.to(conf.device)).cpu()

        # module-level helper of the same name, not this method
        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        self.optimizer.zero_grad()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels, ages in tqdm(iter(self.loader)):
                self.optimizer.zero_grad()
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                # added wrong-pair counts on evaluations
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # LFW evaluation
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    # negative pairs wrongly accepted
                    wrong_list = np.where((self.lfw_issame == False)
                                          & (dist < best_threshold))[0]
                    negative_wrong = len(wrong_list)
                    # positive pairs wrongly rejected
                    wrong_list = np.where((self.lfw_issame == True)
                                          & (dist > best_threshold))[0]
                    positive_wrong = len(wrong_list)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor,
                                   negative_wrong, positive_wrong)

                    # FGNET-C evaluation
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    wrong_list = np.where((self.fgnetc_issame == False)
                                          & (dist2 < best_threshold2))[0]
                    negative_wrong2 = len(wrong_list)
                    wrong_list = np.where((self.fgnetc_issame == True)
                                          & (dist2 > best_threshold2))[0]
                    positive_wrong2 = len(wrong_list)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2, positive_wrong2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with the most recently calculated accuracy
                    # (both branches are identical in the original; kept as-is)
                    if conf.finetune_model_path is not None:
                        self.save_state(conf, accuracy2,
                                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                        + '_' + str(conf.batch_size) + conf.model_name)
                    else:
                        self.save_state(conf, accuracy2,
                                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        print('Hooray!')

    def train_with_growup(self, conf, epochs):
        '''Our method'''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult / child loaders with the same data size
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #        Train head       #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()
                c = (ages == 0)  # select children for enhancement
                embeddings = self.model(imgs)
                if sum(c) > 1:  # there might be no children in the loader's batch
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                elif sum(c) == 1:
                    # a single sample would break BatchNorm in train mode
                    self.growup.eval()
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #     Train discriminator    #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # separate forward passes since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                # the generator should push child embeddings toward 1 ("adult")
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                g_loss = conf.ls_loss(pred_c, labels_a)
                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()
                # g_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # self.evaluate returns four values; the original unpacked
                    # three here, which would raise a ValueError
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate_child(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with the most recently calculated accuracy
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                    + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1

        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        '''Our method, without growup'''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult / child loaders with the same data size
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #        Train head       #
                ###########################
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #     Train discriminator    #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # separate forward passes since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                # the backbone should push child embeddings toward 1 ("adult")
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                g_loss = conf.ls_loss(pred_c, labels_a)
                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # self.evaluate returns four values; discard dist here
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2, _ = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with the most recently calculated accuracy
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                    + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1

        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using a paired dataset.
        TODO: currently identical to train_age_invariant.
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult / child loaders with the same data size
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #        Train head       #
                ###########################
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #     Train discriminator    #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # separate forward passes since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                # the backbone should push child embeddings toward 1 ("adult")
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                g_loss = conf.ls_loss(pred_c, labels_a)
                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # self.evaluate returns four values; discard dist here
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2, _ = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with the most recently calculated accuracy
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                    + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1

        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + '_discriminator_final')

    def analyze_angle(self, conf, name):
        '''Only works on the age-labeled VGG and AgeDB datasets.'''
        angle_table = [{0: set(), 1: set(), 2: set(), 3: set(),
                        4: set(), 5: set(), 6: set(), 7: set()}
                       for i in range(self.class_num)]
        # batch = 0
        # _angle_table = torch.zeros(self.class_num, 8, len(self.loader)//conf.batch_size).to(conf.device)
        if conf.resume_analysis:
            self.loader = []
        for imgs, labels, ages in tqdm(iter(self.loader)):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)

            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)
            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))

            for i in range(len(thetas)):
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(thetas[i][labels[i]].item())

        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append([len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set) if len(single_set) else 0
                # if the set is empty, the average is zero
                for single_set in angle_table[i].values()
            ])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def schedule_lr2(self):
        for params in self.optimizer2.param_groups:
            params['lr'] /= 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in the facebank
        names : recorded names of faces in the facebank
        tta : test-time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = transforms.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self, conf, accuracy, to_save_folder=False, extra=None, model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        # path handling normalized to os.path.join (the original mixed bare
        # string concatenation and pathlib-style '/')
        torch.save(self.model.state_dict(),
                   os.path.join(str(save_path),
                                'lfw_best_model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                    get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(self.head.state_dict(),
                       os.path.join(str(save_path),
                                    'lfw_best_head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                        get_time(), accuracy, self.step, extra)))
            torch.save(self.optimizer.state_dict(),
                       os.path.join(str(save_path),
                                    'lfw_best_optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                        get_time(), accuracy, self.step, extra)))

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None, model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        torch.save(self.model.state_dict(),
                   os.path.join(str(save_path),
                                'model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                    get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(self.head.state_dict(),
                       os.path.join(str(save_path),
                                    'head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                        get_time(), accuracy, self.step, extra)))
            torch.save(self.optimizer.state_dict(),
                       os.path.join(str(save_path),
                                    'optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                        get_time(), accuracy, self.step, extra)))
        if conf.discriminator:
            torch.save(self.growup.state_dict(),
                       os.path.join(str(save_path),
                                    'growup_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                                        get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False, model_only=False, analyze=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(os.path.join(str(save_path), 'model_{}'.format(fixed_str))))
        if not model_only:
            self.head.load_state_dict(
                torch.load(os.path.join(str(save_path), 'head_{}'.format(fixed_str))))
            if not analyze:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(str(save_path), 'optimizer_{}'.format(fixed_str))))
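# Hedged, self-contained sketch of the nearest-embedding matching performed by
# `infer` above: squared L2 distances against a facebank, with -1 returned
# when even the closest match exceeds the threshold. The shapes and the
# threshold value here are illustrative assumptions only.
import torch

def match_embeddings(source_embs, target_embs, threshold=1.5):
    # source_embs: [m, d], target_embs: [n, d]
    diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)  # [m, d, n]
    dist = torch.sum(torch.pow(diff, 2), dim=1)                                   # [m, n]
    minimum, min_idx = torch.min(dist, dim=1)
    min_idx[minimum > threshold] = -1  # no facebank identity close enough
    return min_idx, minimum

# smoke test with random embeddings
src, tgt = torch.randn(4, 512), torch.randn(10, 512)
print(match_embeddings(src, tgt))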
# bert_optimizer.step()
# gen_optimizer.step()
# dis_optimizer.step()

tr_g_loss += g_loss.item()
tr_d_loss += d_loss.item()
nb_tr_examples += src_input_ids.size(0)
nb_tr_steps += 1
global_step += 1

# end of the (elided) training loop: average the accumulated losses
tr_g_loss /= nb_tr_steps
tr_d_loss /= nb_tr_steps

# VALIDATION
bert.eval()
discriminator.eval()

all_preds = np.array([])
all_label_ids = np.array([])
eval_loss = 0
nb_eval_steps = 0
for src_input_ids, src_input_mask, label_ids in val_dataloader:
    src_input_ids = src_input_ids.to(device)
    src_input_mask = src_input_mask.to(device)
    label_ids = label_ids.to(device)
    with torch.no_grad():
        _, doc_rep = bert(src_input_ids, attention_mask=src_input_mask)
        _, logits, probs = discriminator(doc_rep)
    # print(probs)  # debug
    # drop the extra "fake" column (last class) and renormalize over the
    # real classes before taking predictions
    probs = torch.nn.functional.softmax(probs[:, :-1], dim=-1)
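# Hedged illustration of the renormalization step used in the validation loop
# above: the GAN-BERT-style discriminator has one extra "fake" class, which is
# dropped before scoring real classes. The class count and inputs here are
# made up for demonstration.
import torch
import torch.nn.functional as F

raw = torch.randn(2, 5)                      # 4 real classes + 1 fake class
real_probs = F.softmax(raw[:, :-1], dim=-1)  # softmax over real classes only
preds = real_probs.argmax(dim=-1)
print(real_probs.sum(dim=-1))                # each row sums to 1 again
print(preds)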
class Solver(object):

    ####
    def __init__(self, args):
        self.args = args

        self.name = ('%s_etaS_%s_etaH_%s_lamklMin_%s_lamklMax_%s'
                     '_gamma_%s_zDim_%s') % \
            (args.dataset, args.etaS, args.etaH,
             args.lamklMin, args.lamklMax, args.gamma, args.z_dim)
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        if args.dataset.endswith('dsprites'):
            self.nc = 1
        else:
            self.nc = 3

        # default added: the original only set this inside the dataset
        # branches, leaving it undefined for other datasets
        self.eval_metrics = False

        # groundtruth factor labels (only available for "dsprites")
        if self.dataset == 'dsprites':
            # latent factor = (color, shape, scale, orient, pos-x, pos-y)
            # color = {1} (1)
            # shape = {1=square, 2=oval, 3=heart} (3)
            # scale = {0.5, 0.6, ..., 1.0} (6)
            # orient = {2*pi*(k/39)}_{k=0}^39 (40)
            # pos-x = {k/31}_{k=0}^31 (32)
            # pos-y = {k/31}_{k=0}^31 (32)
            # (number of variations = 1*3*6*40*32*32 = 737280)
            latent_values = np.load(
                os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_values.npy'),
                encoding='latin1')
            self.latent_values = latent_values[:, [1, 2, 3, 4, 5]]
            # latent values (actual values); (737280 x 5)
            latent_classes = np.load(
                os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_classes.npy'),
                encoding='latin1')
            self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (737280 x 5)
            self.latent_sizes = np.array([3, 6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == 'oval_dsprites':
            latent_classes = np.load(
                os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_classes.npy'),
                encoding='latin1')
            idx = np.where(latent_classes[:, 1] == 1)[0]  # "oval" shape only
            self.latent_classes = latent_classes[idx, :]
            self.latent_classes = self.latent_classes[:, [2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (245760 x 4)
            latent_values = np.load(
                os.path.join(self.dset_dir, 'dsprites-dataset', 'latents_values.npy'),
                encoding='latin1')
            self.latent_values = latent_values[idx, :]
            self.latent_values = self.latent_values[:, [2, 3, 4, 5]]
            # latent values (actual values); (245760 x 4)
            self.latent_sizes = np.array([6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.z_dim = args.z_dim
        self.etaS = args.etaS
        self.etaH = args.etaH
        self.lamklMin = args.lamklMin
        self.lamklMax = args.lamklMax
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        # self.lr_rvec = args.lr_rvec
        # self.beta1_rvec = args.beta1_rvec
        # self.beta2_rvec = args.beta2_rvec
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:
            self.win_id = dict(DZ='win_DZ', recon='win_recon', kl='win_kl',
                               rvS='win_rvS', rvH='win_rvH')
            self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm',
                                          'recon', 'kl', 'rvS', 'rvH')
            if self.eval_metrics:
                self.win_id['metrics'] = 'win_metrics'

            import visdom
            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter
            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
mkdirs("records") mkdirs("ckpts") mkdirs("outputs") # set run id if args.run_id < 0: # create a new id k = 0 rfname = os.path.join("records", self.name + '_run_0.txt') while os.path.exists(rfname): k += 1 rfname = os.path.join("records", self.name + '_run_%d.txt' % k) self.run_id = k else: # user-provided id self.run_id = args.run_id # finalize name self.name = self.name + '_run_' + str(self.run_id) # records (text file to store console outputs) self.record_file = 'records/%s.txt' % self.name # checkpoints self.ckpt_dir = os.path.join("ckpts", self.name) # outputs self.output_dir_recon = os.path.join("outputs", self.name + '_recon') # dir for reconstructed images self.output_dir_synth = os.path.join("outputs", self.name + '_synth') # dir for synthesized images self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl') # dir for latent traversed images #### create a new model or load a previously saved model self.ckpt_load_iter = args.ckpt_load_iter if self.ckpt_load_iter == 0: # create a new model # create a vae model if args.dataset.endswith('dsprites'): self.encoder = Encoder1(self.z_dim) self.decoder = Decoder1(self.z_dim) else: pass #self.VAE = FactorVAE2(self.z_dim) # create a relevance vector self.rvec = RelevanceVector(self.z_dim) # create a discriminator model self.D = Discriminator(self.z_dim) else: # load a previously saved model print('Loading saved models (iter: %d)...' % self.ckpt_load_iter) self.load_checkpoint() print('...done') if self.use_cuda: print('Models moved to GPU...') self.encoder = self.encoder.cuda() self.decoder = self.decoder.cuda() self.rvec = self.rvec.cuda() self.D = self.D.cuda() print('...done') # get VAE parameters (and rv parameters) vae_params = list(self.encoder.parameters()) + \ list(self.decoder.parameters()) + list(self.rvec.parameters()) # get discriminator parameters dis_params = list(self.D.parameters()) # create optimizers self.optim_vae = optim.Adam(vae_params, lr=self.lr_VAE, betas=[self.beta1_VAE, self.beta2_VAE]) self.optim_dis = optim.Adam(dis_params, lr=self.lr_D, betas=[self.beta1_D, self.beta2_D]) #### def train(self): self.set_mode(train=True) ones = torch.ones(self.batch_size, dtype=torch.long) zeros = torch.zeros(self.batch_size, dtype=torch.long) if self.use_cuda: ones = ones.cuda() zeros = zeros.cuda() # prepare dataloader (iterable) print('Start loading data...') self.data_loader = create_dataloader(self.args) print('...done') # iterators from dataloader iterator1 = iter(self.data_loader) iterator2 = iter(self.data_loader) iter_per_epoch = min(len(iterator1), len(iterator2)) start_iter = self.ckpt_load_iter + 1 epoch = int(start_iter / iter_per_epoch) for iteration in range(start_iter, self.max_iter + 1): # reset data iterators for each epoch if iteration % iter_per_epoch == 0: print('==== epoch %d done ====' % epoch) epoch += 1 iterator1 = iter(self.data_loader) iterator2 = iter(self.data_loader) #============================================ # TRAIN THE VAE (ENC & DEC) #============================================ # sample a mini-batch X, ids = next(iterator1) # (n x C x H x W) if self.use_cuda: X = X.cuda() # enc(X) mu, std, logvar = self.encoder(X) # relevance vector rvlogit, rv = self.rvec() # kl loss kls = -0.5 * (1 + logvar - mu**2 - std**2) # (n x z_dim) klsum = kls.sum(1).mean() lamkl = self.lamklMax - (self.lamklMax - self.lamklMin) * rv loss_kl = (lamkl * kls).sum(1).mean() # reparam'ed samples if self.use_cuda: Eps = torch.cuda.FloatTensor(mu.shape).normal_() else: Eps = torch.randn(mu.shape) Z = mu + Eps * 
            Z = mu + Eps * std

            # dec(Z)
            X_recon = self.decoder(Z)

            # recon loss
            loss_recon = F.binary_cross_entropy_with_logits(
                X_recon, X, reduction='sum').div(X.size(0))

            # dis(rv*Z)
            DZ = self.D(rv * Z)

            # tc loss
            loss_tc = (DZ[:, 0] - DZ[:, 1]).mean()

            # L1 (sparseness) loss
            loss_sparse = rv.sum()

            # entropy loss
            loss_entropy = F.binary_cross_entropy_with_logits(
                rvlogit, rv, reduction='sum')
            # loss_entropy = (rv*(1-rv)).sum()

            # total loss for vae
            vae_loss = loss_recon + loss_kl + self.gamma*loss_tc + \
                self.etaS*loss_sparse + self.etaH*loss_entropy

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            #============================================
            #          TRAIN THE DISCRIMINATOR
            #============================================

            # sample a mini-batch
            X2, ids = next(iterator2)  # (n x C x H x W)
            if self.use_cuda:
                X2 = X2.cuda()

            # enc(X2)
            mu, std, _ = self.encoder(X2)

            # reparameterized samples
            if self.use_cuda:
                Eps = torch.cuda.FloatTensor(mu.shape).normal_()
            else:
                Eps = torch.randn(mu.shape)
            Z = mu + Eps * std

            # relevance vector
            _, rv = self.rvec()
            RZ = rv * Z

            # dis(RZ)
            DZ = self.D(RZ)

            # dim-wise permuted Z over the mini-batch
            perm_Z = []
            for zj in RZ.split(1, 1):
                idx = torch.randperm(Z.size(0))
                perm_zj = zj[idx]
                perm_Z.append(perm_zj)
            RZ_perm = torch.cat(perm_Z, 1)
            RZ_perm = RZ_perm.detach()

            # dis(RZ_perm)
            DZ_perm = self.D(RZ_perm)

            # discriminator loss
            dis_loss = 0.5 * (F.cross_entropy(DZ, zeros) +
                              F.cross_entropy(DZ_perm, ones))

            # update discriminator
            self.optim_dis.zero_grad()
            dis_loss.backward()
            self.optim_dis.step()

            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ('[iter %d (epoch %d)] vae_loss: %.3f | '
                           'dis_loss: %.3f\n    '
                           '(recon: %.3f, kl: %.3f, tc: %.3f, L1: %.3f, H: %.3f)') % \
                    (iteration, epoch, vae_loss.item(), dis_loss.item(),
                     loss_recon.item(), klsum.item(), loss_tc.item(),
                     loss_sparse.item(), loss_entropy.item())
                prn_str += '\n    rv = {}'.format(
                    rv.detach().cpu().numpy().round(2))
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # 1) save the recon images
                self.save_recon(iteration, X, torch.sigmoid(X_recon).data)

                # 2) save the synth images
                self.save_synth(iteration, howmany=100)

                # 3) save the latent traversed images
                if self.dataset.lower() == '3dchairs':
                    self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                else:
                    self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                # compute discriminator accuracy
                p_DZ = F.softmax(DZ, 1)[:, 0].detach()
                p_DZ_perm = F.softmax(DZ_perm, 1)[:, 0].detach()

                # insert line stats
                self.line_gather.insert(iter=iteration,
                                        p_DZ=p_DZ.mean().item(),
                                        p_DZ_perm=p_DZ_perm.mean().item(),
                                        recon=loss_recon.item(),
                                        kl=klsum.item(),
                                        rvS=loss_sparse.item(),
                                        rvH=loss_entropy.item())

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            # evaluate metrics
            if self.eval_metrics and (iteration % self.eval_metrics_iter == 0):
                metric1, _ = self.eval_disentangle_metric1()
                metric2, _ = self.eval_disentangle_metric2()
                prn_str = ('********\n[iter %d (epoch %d)] '
                           'metric1 = %.4f, metric2 = %.4f\n********') % \
                    (iteration, epoch, metric1, metric2)
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

                # (visdom) visualize metrics
                if self.viz_on:
                    self.visualize_line_metrics(iteration, metric1, metric2)

    ####
    def eval_disentangle_metric1(self):
        # some hyperparams
        num_pairs = 800           # number of data pairs (d,y) for majority-vote classification
        bs = 50                   # batch size
        nsamps_per_factor = 100   # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points, factor-agnostic
        dl = DataLoader(self.data_loader.dataset, batch_size=bs, shuffle=True,
                        num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):
            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()
            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
            M.append(mub.cpu().detach().numpy())
        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor fixed"
        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids (np.int is deprecated)

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor
            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):
                # a true factor (id and class value) to fix
                fac_id = j
                fac_class = np.random.randint(self.latent_sizes[fac_id])

                # randomly select images (with the fixed factor)
                indices = np.where(
                    self.latent_classes[:, fac_id] == fac_class)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]

                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0
                # true factor id (will become the class label)
                true_factor_ids[i] = fac_id

                i += 1

        # 3) evaluate majority-vote classification accuracy

        # inputs in the paired data for classification
        smallest_var_dims = np.argmin(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[smallest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric1 = (num_pairs - num_errs) / num_pairs  # metric = accuracy
        self.set_mode(train=True)
        return metric1, C

    ####
    def eval_disentangle_metric2(self):
        # some hyperparams
        num_pairs = 800           # number of data pairs (d,y) for majority-vote classification
        bs = 50                   # batch size
        nsamps_per_factor = 100   # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points, factor-agnostic
        dl = DataLoader(self.data_loader.dataset, batch_size=bs, shuffle=True,
                        num_workers=self.args.num_workers, pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):
            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()
            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
            M.append(mub.cpu().detach().numpy())
        M = np.concatenate(M, 0)

        # estimate sample variance and mean of latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor varied"
        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, int)  # true factor ids (np.int is deprecated)

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor
            # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):
                # randomly choose true factors (ids and class values) to fix
                fac_ids = list(np.setdiff1d(factor_ids, j))
                fac_classes = [np.random.randint(self.latent_sizes[k])
                               for k in fac_ids]

                # randomly select images (with the other factors fixed)
                if len(fac_ids) > 1:
                    indices = np.where(
                        np.sum(self.latent_classes[:, fac_ids] == fac_classes, 1)
                        == len(fac_ids))[0]
                else:
                    indices = np.where(
                        self.latent_classes[:, fac_ids] == fac_classes)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]

                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = j

                i += 1

        # 3) evaluate majority-vote classification accuracy

        # inputs in the paired data for classification
        largest_var_dims = np.argmax(
            vars_per_factor / (vars_agn_factor + 1e-20), axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[largest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric2 = (num_pairs - num_errs) / num_pairs  # metric = accuracy
        self.set_mode(train=True)
        return metric2, C
    ####
    def save_recon(self, iters, true_images, recon_images):
        # make a merge of true and recon, eg,
        #   merged[0,...] = true[0,...],
        #   merged[1,...] = recon[0,...],
        #   merged[2,...] = true[1,...],
        #   merged[3,...] = recon[1,...], ...
        n = true_images.shape[0]
        perm = torch.arange(0, 2 * n).view(2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)
        merged = torch.cat([true_images, recon_images], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'recon_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(tensor=merged, filename=fname,
                   nrow=2 * int(np.sqrt(n)), pad_value=1)

    ####
    def save_synth(self, iters, howmany=100):
        self.set_mode(train=False)

        decoder = self.decoder
        Z = torch.randn(howmany, self.z_dim)
        if self.use_cuda:
            Z = Z.cuda()

        # do synthesis
        X = torch.sigmoid(decoder(Z)).data.cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_synth, 'synth_%s.jpg' % iters)
        mkdirs(self.output_dir_synth)
        save_image(tensor=X, filename=fname,
                   nrow=int(np.sqrt(howmany)), pad_value=1)

        self.set_mode(train=True)

    ####
    def save_traverse(self, iters, limb=-3, limu=3, inter=2 / 3, loc=-1):
        self.set_mode(train=False)

        encoder = self.encoder
        decoder = self.decoder
        interpolation = torch.arange(limb, limu + 0.001, inter)

        i = np.random.randint(self.N)
        random_img = self.data_loader.dataset.__getitem__(i)[0]
        if self.use_cuda:
            random_img = random_img.cuda()
        random_img = random_img.unsqueeze(0)
        random_img_zmu, _, _ = encoder(random_img)

        if self.dataset.lower() == 'dsprites':
            fixed_idx1 = 87040   # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {'fixed_square': fixed_img1, 'fixed_ellipse': fixed_img2,
                   'fixed_heart': fixed_img3, 'random_img': random_img}

            Z = {'fixed_square': fixed_img_zmu1, 'fixed_ellipse': fixed_img_zmu2,
                 'fixed_heart': fixed_img_zmu3, 'random_img': random_img_zmu}

        elif self.dataset.lower() == 'oval_dsprites':
            fixed_idx1 = 87040   # oval1
            fixed_idx2 = 220045  # oval2
            fixed_idx3 = 178560  # oval3

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {'fixed1': fixed_img1, 'fixed2': fixed_img2,
                   'fixed3': fixed_img3, 'random_img': random_img}

            Z = {'fixed1': fixed_img_zmu1, 'fixed2': fixed_img_zmu2,
                 'fixed3': fixed_img_zmu3, 'random_img': random_img_zmu}
        # elif self.dataset.lower() == 'celeba':
        #     fixed_idx1 = 191281  # 'CelebA/img_align_celeba/191282.jpg'
        #     fixed_idx2 = 143307  # 'CelebA/img_align_celeba/143308.jpg'
        #     fixed_idx3 = 101535  # 'CelebA/img_align_celeba/101536.jpg'
        #     fixed_idx4 = 70059   # 'CelebA/img_align_celeba/070060.jpg'
        #
        #     fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
        #     fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
        #     fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
        #
        #     fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
        #     fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
        #     fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
        #
        #     fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
        #     fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
        #     fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
        #
        #     fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
        #     fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
        #     fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]
        #
        #     Z = {'fixed_1': fixed_img_z1, 'fixed_2': fixed_img_z2,
        #          'fixed_3': fixed_img_z3, 'fixed_4': fixed_img_z4,
        #          'random': random_img_zmu}
        #
        # elif self.dataset.lower() == '3dchairs':
        #     fixed_idx1 = 40919  # 3DChairs/images/4682_image_052_p030_t232_r096.png
        #     fixed_idx2 = 5172   # 3DChairs/images/14657_image_020_p020_t232_r096.png
        #     fixed_idx3 = 22330  # 3DChairs/images/30099_image_052_p030_t232_r096.png
        #
        #     fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
        #     fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
        #     fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
        #
        #     fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
        #     fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
        #     fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
        #
        #     fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
        #     fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
        #     fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
        #
        #     Z = {'fixed_1': fixed_img_z1, 'fixed_2': fixed_img_z2,
        #          'fixed_3': fixed_img_z3, 'random': random_img_zmu}

        else:
            raise NotImplementedError

        # do traversal and collect generated images
        gifs = []
        for key in Z:
            z_ori = Z[key]
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    gifs.append(sample)

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation),
                         self.nc, 64, 64).transpose(1, 2)
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(tensor=I.cpu(),
                           filename=os.path.join(out_dir, '%s_%03d.jpg' % (key, j)),
                           nrow=1 + self.z_dim, pad_value=1)
            # make animated gif
            grid2gif(out_dir, key, str(os.path.join(out_dir, key + '.gif')), delay=10)

        self.set_mode(train=True)

    ####
    def viz_init(self):
        self.viz.close(env=self.name + '/lines', win=self.win_id['DZ'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['rvS'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['rvH'])
        if self.eval_metrics:
            self.viz.close(env=self.name + '/lines', win=self.win_id['metrics'])

    ####
    def visualize_line(self):
        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kl = torch.Tensor(data['kl'])
        rvS = torch.Tensor(data['rvS'])
        rvH = torch.Tensor(data['rvH'])
        p_DZ = torch.Tensor(data['p_DZ'])
        p_DZ_perm = torch.Tensor(data['p_DZ_perm'])
        p_DZs = torch.stack([p_DZ, p_DZ_perm], -1)  # (#items x 2)

        self.viz.line(X=iters, Y=p_DZs, env=self.name + '/lines',
                      win=self.win_id['DZ'], update='append',
                      opts=dict(xlabel='iter', ylabel='D(z)',
                                title='Discriminator-Z',
                                legend=['D(z)', 'D(z_perm)']))

        self.viz.line(X=iters, Y=recon, env=self.name + '/lines',
                      win=self.win_id['recon'], update='append',
                      opts=dict(xlabel='iter', ylabel='recon loss',
                                title='Reconstruction'))

        self.viz.line(X=iters, Y=kl, env=self.name + '/lines',
                      win=self.win_id['kl'], update='append',
                      opts=dict(xlabel='iter', ylabel='E_x[kl(q(z|x)||p(z)]',
                                title='KL divergence'))

        self.viz.line(X=iters, Y=rvS, env=self.name + '/lines',
                      win=self.win_id['rvS'], update='append',
                      opts=dict(xlabel='iter', ylabel='||rv||_1',
                                title='L1 norm of relevance vector'))

        self.viz.line(X=iters, Y=rvH, env=self.name + '/lines',
                      win=self.win_id['rvH'], update='append',
                      opts=dict(xlabel='iter', ylabel='H(rv)',
                                title='Entropy of relevance vector'))

    ####
    def visualize_line_metrics(self, iters, metric1, metric2):
        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(X=iters, Y=metrics, env=self.name + '/lines',
                      win=self.win_id['metrics'], update='append',
                      opts=dict(xlabel='iter', ylabel='metrics',
                                title='Disentanglement metrics',
                                legend=['metric1', 'metric2']))

    ####
    def set_mode(self, train=True):
        if train:
            self.encoder.train()
            self.decoder.train()
            self.D.train()
        else:
            self.encoder.eval()
            self.decoder.eval()
            self.D.eval()

    ####
    def save_checkpoint(self, iteration):
        encoder_path = os.path.join(self.ckpt_dir, 'iter_%s_encoder.pt' % iteration)
        decoder_path = os.path.join(self.ckpt_dir, 'iter_%s_decoder.pt' % iteration)
        rvec_path = os.path.join(self.ckpt_dir, 'iter_%s_rvec.pt' % iteration)
        D_path = os.path.join(self.ckpt_dir, 'iter_%s_D.pt' % iteration)

        mkdirs(self.ckpt_dir)

        torch.save(self.encoder, encoder_path)
        torch.save(self.decoder, decoder_path)
        torch.save(self.rvec, rvec_path)
        torch.save(self.D, D_path)

    ####
    def load_checkpoint(self):
        encoder_path = os.path.join(self.ckpt_dir, 'iter_%s_encoder.pt' % self.ckpt_load_iter)
        decoder_path = os.path.join(self.ckpt_dir, 'iter_%s_decoder.pt' % self.ckpt_load_iter)
        rvec_path = os.path.join(self.ckpt_dir, 'iter_%s_rvec.pt' % self.ckpt_load_iter)
        D_path = os.path.join(self.ckpt_dir, 'iter_%s_D.pt' % self.ckpt_load_iter)

        if self.use_cuda:
            self.encoder = torch.load(encoder_path)
            self.decoder = torch.load(decoder_path)
            self.rvec = torch.load(rvec_path)
            self.D = torch.load(D_path)
        else:
            self.encoder = torch.load(encoder_path, map_location='cpu')
            self.decoder = torch.load(decoder_path, map_location='cpu')
            self.rvec = torch.load(rvec_path, map_location='cpu')
            self.D = torch.load(D_path, map_location='cpu')
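# Hedged, self-contained sketch of the dim-wise permutation used in
# Solver.train above to build "independent" negatives for the
# total-correlation discriminator: each latent dimension is shuffled across
# the batch independently, destroying dependencies between dimensions while
# preserving each dimension's marginal distribution.
import torch

def permute_dims(z):
    # z: [batch, z_dim]
    perm_cols = []
    for zj in z.split(1, 1):             # one column at a time
        idx = torch.randperm(z.size(0))  # independent shuffle per dim
        perm_cols.append(zj[idx])
    return torch.cat(perm_cols, 1).detach()

z = torch.randn(8, 4)
print(permute_dims(z).shape)  # torch.Size([8, 4])

# A second hedged sketch: the final step of both disentanglement metrics above
# is a contingency table of (selected latent dim) x (true factor), scored as a
# majority-vote classifier. The counts below are made up for illustration.
import numpy as np

C = np.array([[40, 2, 1],   # rows: latent dims, cols: true factors
              [3, 35, 4],
              [1, 2, 12]], dtype=float)
num_pairs = C.sum()
num_errs = sum(C[k].sum() - C[k].max() for k in range(C.shape[0]))
print((num_pairs - num_errs) / num_pairs)  # majority-vote accuracy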
# clip critic weights between -0.01, 0.01
for p in critic.parameters():
    p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
gen_fake = critic(fake).reshape(-1)
loss_gen = -torch.mean(gen_fake)
gen.zero_grad()
loss_gen.backward()
opt_gen.step()

# Print losses occasionally and print to tensorboard
if batch_idx % 100 == 0 and batch_idx > 0:
    gen.eval()
    critic.eval()
    print(
        f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
        f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}")

    with torch.no_grad():
        fake = gen(noise)
        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

    # restore train mode after logging (the original snippet never
    # switched back, which would leave batchnorm in eval mode)
    gen.train()
    critic.train()
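# Hedged, self-contained illustration of the WGAN weight-clipping step above:
# after each critic update, every parameter is clamped in place to [-c, c] to
# crudely enforce the Lipschitz constraint. The tiny critic here is made up
# purely for demonstration.
import torch
import torch.nn as nn

WEIGHT_CLIP = 0.01
critic = nn.Sequential(nn.Linear(16, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))

# ... the critic optimizer step would happen here ...
for p in critic.parameters():
    p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

print(max(p.abs().max().item() for p in critic.parameters()))  # <= 0.01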
def train(resume_path=None, jigsaw_path=None):
    writer = SummaryWriter('../runs/' + hparams.exp_name)
    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(data_csv=hparams.train_csv, data_dir=hparams.train_dir,
                              augment=hparams.augment,
                              transform=transforms.Compose([
                                  transforms.Resize(hparams.image_shape),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5027, 0.5027, 0.5027),
                                                       (0.2915, 0.2915, 0.2915))
                              ]))
    validation_dataset = ChestData(data_csv=hparams.valid_csv, data_dir=hparams.valid_dir,
                                   transform=transforms.Compose([
                                       transforms.Resize(hparams.image_shape),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5027, 0.5027, 0.5027),
                                                            (0.2915, 0.2915, 0.2915))
                                   ]))
    # train_sampler = WeightedRandomSampler()
    train_loader = DataLoader(train_dataset, batch_size=hparams.batch_size,
                              shuffle=True, num_workers=2)
    validation_loader = DataLoader(validation_dataset, batch_size=hparams.batch_size,
                                   shuffle=True, num_workers=2)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    # NOTE: BCELoss expects probabilities, so despite the `pred_logits`
    # naming below, the discriminator is assumed to end in a sigmoid.
    adversarial_loss = torch.nn.BCELoss().to(hparams.gpu_device)
    discriminator = Discriminator().to(hparams.gpu_device)
    # if hparams.cuda:
    #     discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)

    params_count = 0
    for param in discriminator.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.pretrained:
        # discriminator.apply(weights_init_normal)
        pass

    # if jigsaw_path:
    #     jigsaw = Jigsaw().to(hparams.gpu_device)
    #     if hparams.cuda:
    #         jigsaw = nn.DataParallel(jigsaw, device_ids=hparams.device_ids)
    #     checkpoints = torch.load(jigsaw_path, map_location=hparams.gpu_device)
    #     jigsaw.load_state_dict(checkpoints['discriminator_state_dict'])
    #     discriminator.module.model.features = jigsaw.module.feature.features
    #     print('loaded pretrained feature extractor from {} ..'.format(jigsaw_path))

    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=hparams.learning_rate, betas=(0.9, 0.999))
    scheduler_D = ReduceLROnPlateau(optimizer_D, mode='min', factor=0.3,
                                    patience=0, verbose=True, cooldown=0)

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    def validation(discriminator_, send_stats=False, epoch=0):
        print('Validating model on {0} examples. '.format(len(validation_dataset)))
        with torch.no_grad():
            pred_logits_list = []
            labels_list = []
            for (img, labels, imgs_names) in tqdm(validation_loader):
                img = Variable(img.float(), requires_grad=False)
                labels = Variable(labels.float(), requires_grad=False)
                img_ = img.to(hparams.gpu_device)
                labels = labels.to(hparams.gpu_device)
                # in eval mode the discriminator is assumed to return only the
                # main head (no aux logits, unlike the training pass below)
                pred_logits = discriminator_(img_)
                pred_logits_list.append(pred_logits)
                labels_list.append(labels)
            pred_logits = torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)
            val_loss = adversarial_loss(pred_logits, labels)
        return accuracy_metrics(labels.long(), pred_logits), val_loss
(log saved in:{})'.format(hparams.exp_name)) start_time = time.time() best_valid_auc = 0 # print(model) for epoch in range(hparams.num_epochs): for batch, (imgs, labels, imgs_name) in enumerate(tqdm(train_loader)): imgs = Variable(imgs.float(), requires_grad=False) labels = Variable(labels.float(), requires_grad=False) imgs_ = imgs.to(hparams.gpu_device) labels = labels.to(hparams.gpu_device) # --------------------- # Train Discriminator # --------------------- optimizer_D.zero_grad() pred_logits, aux_logits = discriminator(imgs_) d_loss1 = adversarial_loss(pred_logits, labels) d_loss2 = adversarial_loss(aux_logits, labels) d_loss = d_loss1 + 0.4 * d_loss2 d_loss.backward() optimizer_D.step() writer.add_scalar('d_loss', d_loss.item(), global_step=batch+epoch*len(train_loader)) pred_labels = (pred_logits >= hparams.thresh) pred_labels = pred_labels.float() # if batch % hparams.print_interval == 0: # auc, f1, acc, _, _ = accuracy_metrics(pred_labels, labels.long(), pred_logits) # print('[Epoch - {0:.1f}, batch - {1:.3f}, d_loss - {2:.6f}, acc - {3:.4f}, f1 - {4:.5f}, auc - {5:.4f}]'.\ # format(1.0*epoch, 100.0*batch/len(train_loader), d_loss.item(), acc['avg'], f1[hparams.avg_mode], auc[hparams.avg_mode])) (val_auc, val_f1, val_acc, val_conf_mat, best_thresh), val_loss = validation(discriminator.eval(), epoch=epoch) discriminator = discriminator.train() for lbl in range(hparams.num_classes): fig = plot_cf(val_conf_mat[lbl]) writer.add_figure('val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch) plt.close(fig) writer.add_scalar('val_f1_{}'.format(hparams.id_to_class[lbl]), val_f1[lbl], global_step=epoch) writer.add_scalar('val_auc_{}'.format(hparams.id_to_class[lbl]), val_auc[lbl], global_step=epoch) writer.add_scalar('val_acc_{}'.format(hparams.id_to_class[lbl]), val_acc[lbl], global_step=epoch) writer.add_scalar('val_f1_{}'.format('micro'), val_f1['micro'], global_step=epoch) writer.add_scalar('val_auc_{}'.format('micro'), val_auc['micro'], global_step=epoch) writer.add_scalar('val_f1_{}'.format('macro'), val_f1['macro'], global_step=epoch) writer.add_scalar('val_auc_{}'.format('macro'), val_auc['macro'], global_step=epoch) writer.add_scalar('val_loss', val_loss, global_step=epoch) writer.add_scalar('val_f1', val_f1[hparams.avg_mode], global_step=epoch) writer.add_scalar('val_auc', val_auc[hparams.avg_mode], global_step=epoch) writer.add_scalar('val_acc', val_acc['avg'], global_step=epoch) scheduler_D.step(val_loss) writer.add_scalar('learning_rate', optimizer_D.param_groups[0]['lr'], global_step=epoch) torch.save({ 'epoch': epoch, 'discriminator_state_dict': discriminator.state_dict(), 'optimizer_D_state_dict': optimizer_D.state_dict(), }, hparams.model+'.'+str(epoch)) if best_valid_auc <= val_auc[hparams.avg_mode]: best_valid_auc = val_auc[hparams.avg_mode] for lbl in range(hparams.num_classes): fig = plot_cf(val_conf_mat[lbl]) writer.add_figure('best_val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch) plt.close(fig) torch.save({ 'epoch': epoch, 'discriminator_state_dict': discriminator.state_dict(), 'optimizer_D_state_dict': optimizer_D.state_dict(), }, hparams.model+'.best') print('best model on validation set saved.') print('[Epoch - {0:.1f} ---> val_auc - {1:.4f}, current_lr - {2:.6f}, val_loss - {3:.4f}, best_val_auc - {4:.4f}, val_acc - {5:.4f}, val_f1 - {6:.4f}] - time - {7:.1f}'\ .format(1.0*epoch, val_auc[hparams.avg_mode], optimizer_D.param_groups[0]['lr'], val_loss, best_valid_auc, val_acc['avg'], val_f1[hparams.avg_mode], 
time.time()-start_time)) start_time = time.time()
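# train() above checkpoints {'epoch', 'discriminator_state_dict',
# 'optimizer_D_state_dict'} every epoch and keeps a '.best' copy. A minimal
# resume sketch; resume_from_checkpoint is an illustrative helper, not part
# of the original code, and assumes the same Discriminator/hparams setup:
def resume_from_checkpoint(path):
    discriminator = Discriminator().to(hparams.gpu_device)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=hparams.learning_rate, betas=(0.9, 0.999))
    checkpoint = torch.load(path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # continue after the saved epoch
    return discriminator, optimizer_D, start_epoch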
class BiAAE(object):
    def __init__(self, params):
        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang, params.tgt_lang, params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)
        self.X_AE = AE(params)
        self.Y_AE = AE(params)
        self.D_X = Discriminator(input_size=params.d_input_size, hidden_size=params.d_hidden_size, output_size=params.d_output_size)
        self.D_Y = Discriminator(input_size=params.d_input_size, hidden_size=params.d_hidden_size, output_size=params.d_output_size)
        self.nets = [self.X_AE, self.Y_AE, self.D_X, self.D_Y]
        self.loss_fn = torch.nn.BCELoss()
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):  # orthogonal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):  # xavier_normal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):  # identity-matrix initialization
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(torch.diag(torch.ones(self.params.g_input_size)))

    def freeze(self, m):
        for p in m.parameters():
            p.requires_grad = False

    def defreeze(self, m):
        for p in m.parameters():
            p.requires_grad = True

    def init_state(self, seed=-1):
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()
        print('Initializing the model...')
        self.X_AE.apply(self.weights_init)  # G init scheme can be changed
        self.Y_AE.apply(self.weights_init)  # G init scheme can be changed
        self.D_X.apply(self.weights_init2)
        # print(self.D_X.map1.weight)
        self.D_Y.apply(self.weights_init2)

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        # Load data
        if not os.path.exists(self.params.data_dir):
            print("Data path doesn't exist: %s" % self.params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)
        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]
        en = src_emb
        it = tgt_emb
        # eval = Evaluator(self.params, en, it, torch.cuda.is_available())
        AE_optimizer = optim.SGD(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())), lr=self.params.g_learning_rate)
        D_optimizer = optim.SGD(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=self.params.d_learning_rate)
        D_A_acc_epochs = []
        D_B_acc_epochs = []
        D_A_loss_epochs = []
        D_B_loss_epochs = []
        d_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        L_Z_loss_epoches = []
        acc_epochs = []
        criterion_epochs = []
        best_valid_metric = -100
        try:
            for epoch in range(self.params.num_epochs):
                D_A_losses = []
                D_B_losses = []
                G_AB_losses = []
                G_AB_recon = []
                G_BA_losses = []
                G_adv_losses = []
                G_BA_recon = []
                L_Z_losses = []
                d_losses = []
                g_losses = []
                hit_A = 0
                hit_B = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                label_D = to_variable(torch.FloatTensor(2 * self.params.mini_batch_size).zero_())
                label_D[:self.params.
mini_batch_size] = 1 - self.params.smoothing label_D[self.params.mini_batch_size:] = self.params.smoothing label_G = to_variable( torch.FloatTensor(self.params.mini_batch_size).zero_()) label_G = label_G + 1 - self.params.smoothing for mini_batch in range( 0, self.params.iters_in_epoch // self.params.mini_batch_size): for d_index in range(self.params.d_steps): D_optimizer.zero_grad() # Reset the gradients self.D_X.train() self.D_Y.train() view_X, view_Y = self.get_batch_data_fast(en, it) # Discriminator X Y_Z = self.Y_AE.encode(view_Y).detach() fake_X = self.X_AE.decode(Y_Z).detach() input = torch.cat([view_X, fake_X], 0) pred_A = self.D_X(input) D_A_loss = self.loss_fn(pred_A, label_D) # Discriminator Y X_Z = self.X_AE.encode(view_X).detach() fake_Y = self.Y_AE.decode(X_Z).detach() input = torch.cat([view_Y, fake_Y], 0) pred_B = self.D_Y(input) D_B_loss = self.loss_fn(pred_B, label_D) D_loss = D_A_loss + self.params.gate * D_B_loss D_loss.backward( ) # compute/store gradients, but don't change params d_losses.append(to_numpy(D_loss.data)) D_A_losses.append(to_numpy(D_A_loss.data)) D_B_losses.append(to_numpy(D_B_loss.data)) discriminator_decision_A = to_numpy(pred_A.data) hit_A += np.sum( discriminator_decision_A[:self.params. mini_batch_size] >= 0.5) hit_A += np.sum( discriminator_decision_A[self.params. mini_batch_size:] < 0.5) discriminator_decision_B = to_numpy(pred_B.data) hit_B += np.sum( discriminator_decision_B[:self.params. mini_batch_size] >= 0.5) hit_B += np.sum( discriminator_decision_B[self.params. mini_batch_size:] < 0.5) D_optimizer.step( ) # Only optimizes D's parameters; changes based on stored gradients from backward() # Clip weights #_clip(self.D_X, self.params.clip_value) #_clip(self.D_Y, self.params.clip_value) sys.stdout.write( "[%d/%d] :: Discriminator Loss: %.3f \r" % (mini_batch, self.params.iters_in_epoch // self.params.mini_batch_size, np.asscalar(np.mean(d_losses)))) sys.stdout.flush() total += 2 * self.params.mini_batch_size * self.params.d_steps for g_index in range(self.params.g_steps): # 2. 
Train G on D's response (but DO NOT train D on these labels) AE_optimizer.zero_grad() self.D_X.eval() self.D_Y.eval() view_X, view_Y = self.get_batch_data_fast(en, it) # Generator X_AE ## adversarial loss X_Z = self.X_AE.encode(view_X) X_recon = self.X_AE.decode(X_Z) Y_fake = self.Y_AE.decode(X_Z) pred_Y = self.D_Y(Y_fake) L_adv_X = self.loss_fn(pred_Y, label_G) L_recon_X = 1.0 - torch.mean( self.loss_fn2(view_X, X_recon)) # Generator Y_AE # adversarial loss Y_Z = self.Y_AE.encode(view_Y) Y_recon = self.Y_AE.decode(Y_Z) X_fake = self.X_AE.decode(Y_Z) pred_X = self.D_X(X_fake) L_adv_Y = self.loss_fn(pred_X, label_G) ### autoAE Loss L_recon_Y = 1.0 - torch.mean( self.loss_fn2(view_Y, Y_recon)) # cross-lingual Loss L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z)) G_loss = self.params.adv_weight * (self.params.gate*L_adv_X + L_adv_Y) + \ self.params.mono_weight * (L_recon_X+L_recon_Y) + \ self.params.cross_weight * L_Z G_loss.backward() g_losses.append(to_numpy(G_loss.data)) G_AB_losses.append(to_numpy(L_adv_X.data)) G_BA_losses.append(to_numpy(L_adv_Y.data)) G_adv_losses.append( to_numpy(L_adv_Y.data + L_adv_X.data)) G_AB_recon.append(to_numpy(L_recon_X.data)) G_BA_recon.append(to_numpy(L_recon_Y.data)) L_Z_losses.append(to_numpy(L_Z.data)) AE_optimizer.step() # Only optimizes G's parameters sys.stdout.write( "[%d/%d] :: Generator Loss: %.3f \r" % (mini_batch, self.params.iters_in_epoch // self.params.mini_batch_size, np.asscalar(np.mean(g_losses)))) sys.stdout.flush() '''for each epoch''' D_A_acc_epochs.append(hit_A / total) D_B_acc_epochs.append(hit_B / total) G_AB_loss_epochs.append(np.asscalar(np.mean(G_AB_losses))) G_BA_loss_epochs.append(np.asscalar(np.mean(G_BA_losses))) D_A_loss_epochs.append(np.asscalar(np.mean(D_A_losses))) D_B_loss_epochs.append(np.asscalar(np.mean(D_B_losses))) G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon))) G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon))) L_Z_loss_epoches.append(np.asscalar(np.mean(L_Z_losses))) d_loss_epochs.append(np.asscalar(np.mean(d_losses))) g_loss_epochs.append(np.asscalar(np.mean(g_losses))) print( "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins" .format(epoch, np.asscalar(np.mean(d_losses)), 0.5 * (hit_A + hit_B) / total, np.asscalar(np.mean(g_losses)), (timer() - start_time) / 60)) if (epoch + 1) % self.params.print_every == 0: # No need for discriminator weights X_Z = self.X_AE.encode(Variable(en)).data Y_Z = self.Y_AE.encode(Variable(it)).data mstart_time = timer() for method in [self.params.eval_method]: results = get_word_translation_accuracy( self.params.src_lang, src_word2id, X_Z, self.params.tgt_lang, tgt_word2id, Y_Z, method=method, dico_eval=self.params.eval_file) acc1 = results[0][1] print('{} takes {:.2f}s'.format(method, timer() - mstart_time)) print('Method:{} score:{:.4f}'.format(method, acc1)) csls, size = dist_mean_cosine(self.params, X_Z, Y_Z) criterion = size if criterion > best_valid_metric: print("New criterion value: {}".format(criterion)) best_valid_metric = criterion fp = open( self.tune_best_dir + "/seed_{}_dico_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp". 
format(seed, self.params.dico_build, self.params.gate, epoch, acc1), 'w') fp.close() torch.save( self.X_AE.state_dict(), self.tune_best_dir + '/seed_{}_dico_{}_gate_{}_best_X.t7'.format( seed, self.params.dico_build, self.params.gate)) torch.save( self.Y_AE.state_dict(), self.tune_best_dir + '/seed_{}_dico_{}_gate_{}_best_Y.t7'.format( seed, self.params.dico_build, self.params.gate)) torch.save( self.D_X.state_dict(), self.tune_best_dir + '/seed_{}_dico_{}_gate_{}_best_Dx.t7'.format( seed, self.params.dico_build, self.params.gate)) torch.save( self.D_Y.state_dict(), self.tune_best_dir + '/seed_{}_dico_{}_gate_{}__best_Dy.t7'.format( seed, self.params.dico_build, self.params.gate)) # Saving generator weights fp = open( self.tune_dir + "/seed_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp".format( seed, self.params.gate, epoch, acc1), 'w') fp.close() acc_epochs.append(acc1) criterion_epochs.append(criterion) criterion_fb, epoch_fb = max([ (score, index) for index, score in enumerate(criterion_epochs) ]) fp = open( self.tune_best_dir + "/seed_{}_dico_{}_gate_{}_epoch_{}_Acc_{:.3f}_{:.4f}.cslsfb". format(seed, self.params.gate, self.params.dico_build, epoch_fb, acc_epochs[epoch_fb], criterion_fb), 'w') fp.close() # Save the plot for discriminator accuracy and generator loss fig = plt.figure() plt.plot(range(0, len(D_A_acc_epochs)), D_A_acc_epochs, color='b', label='D_A') plt.plot(range(0, len(D_B_acc_epochs)), D_B_acc_epochs, color='r', label='D_B') plt.ylabel('D_accuracy') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(D_A_loss_epochs)), D_A_loss_epochs, color='b', label='D_A') plt.plot(range(0, len(D_B_loss_epochs)), D_B_loss_epochs, color='r', label='D_B') plt.ylabel('D_losses') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(G_AB_loss_epochs)), G_AB_loss_epochs, color='b', label='G_AB') plt.plot(range(0, len(G_BA_loss_epochs)), G_BA_loss_epochs, color='r', label='G_BA') plt.ylabel('G_losses') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB') plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA') plt.ylabel('G_recon_loss') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_G_Recon.png'.format(seed)) # fig = plt.figure() # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z') # plt.ylabel('L_Z_loss') # plt.xlabel('epochs') # plt.legend() # fig.savefig(tune_dir + '/seed_{}_L_Z.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(acc_epochs)), acc_epochs, color='b', label='trans_acc1') plt.ylabel('trans_acc') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed)) ''' fig = plt.figure() plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls') plt.ylabel('csls') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed)) ''' fig = plt.figure() plt.plot(range(0, len(g_loss_epochs)), g_loss_epochs, color='b', label='G_loss') plt.ylabel('g_loss') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(d_loss_epochs)), d_loss_epochs, color='b', label='csls') plt.ylabel('D_loss') plt.xlabel('epochs') plt.legend() 
fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed)) plt.close('all') except KeyboardInterrupt: print("Interrupted.. saving model !!!") torch.save(self.X_AE.state_dict(), self.tune_dir + '/X_AE_model_interrupt.t7') torch.save(self.Y_AE.state_dict(), self.tune_dir + '/Y_AE_model_interrupt.t7') torch.save(self.D_X.state_dict(), self.tune_dir + '/D_X_model_interrupt.t7') torch.save(self.D_Y.state_dict(), self.tune_dir + '/D_y_model_interrupt.t7') exit() return def get_batch_data_fast(self, emb_en, emb_it): params = self.params random_en_indices = torch.LongTensor(params.mini_batch_size).random_( params.most_frequent_sampling_size) random_it_indices = torch.LongTensor(params.mini_batch_size).random_( params.most_frequent_sampling_size) en_batch = to_variable(emb_en)[random_en_indices.cuda()] it_batch = to_variable(emb_it)[random_it_indices.cuda()] return en_batch, it_batch
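# BiAAE.train() builds one-sided smoothed discriminator targets: the first
# half of label_D (real samples) gets 1 - smoothing, the second half (fakes)
# gets smoothing, and the generator target is the smoothed "real" value.
# A worked example with illustrative values (mini_batch_size=4, smoothing=0.1):
import torch

mini_batch_size, smoothing = 4, 0.1
label_D = torch.zeros(2 * mini_batch_size)
label_D[:mini_batch_size] = 1 - smoothing   # real half -> 0.9
label_D[mini_batch_size:] = smoothing       # fake half -> 0.1
print(label_D)  # 0.9 for the first four entries, 0.1 for the last four
label_G = torch.zeros(mini_batch_size) + 1 - smoothing  # generator aims for 0.9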
# reset the gradients to avoid gradients accumulation discriminator_optim.zero_grad() # compute the gradients of loss w.r.t weights discriminator_loss.backward(retain_graph=True) # update the weights discriminator_optim.step() # Store the loss for later use train_history['discriminator_loss'].append( discriminator_loss.item()) # Train generator # create fake labels trick = torch.tensor( np.array([1] * noise_size), dtype=torch.float32).unsqueeze(dim=1).to(device) discriminator.eval() # freeze the discriminator if stop == 0: generator.train() # enable training mode for the generator generator_loss = generator_criterion( discriminator(generated_data), trick) generator_optim.zero_grad() generator_loss.backward(retain_graph=True) generator_optim.step() train_history['generator_loss'].append(generator_loss.item()) else: generator.eval() # enable evaluation mode generator_loss = generator_criterion( discriminator(generated_data), trick) train_history['generator_loss'].append(generator_loss.item()) # unfreeze the discriminator's layers
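# Note on the snippet above: discriminator.eval() only switches layers such
# as dropout/batch-norm into inference mode; it does not block gradients.
# The discriminator stays fixed during the generator step only because
# discriminator_optim.step() is never called there. To freeze parameters
# explicitly (a sketch of an alternative, not the snippet's own mechanism):
def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

set_requires_grad(discriminator, False)   # freeze before the generator step
# ... generator forward / backward / step ...
set_requires_grad(discriminator, True)    # unfreeze for the next D update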
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--G_path", help="Generator model path")
    parser.add_argument("--D_path", help="Discriminator model path")
    parser.add_argument("--dir_path", help="path to load LR images and store SR images")
    parser.add_argument("--batch_size", type=int, help="Batch size")
    parser.add_argument("--res_blocks", type=int, help="No. of resnet blocks")
    parser.add_argument("--in_channels", type=int, help="No. of input channels")
    parser.add_argument("--train", type=int, help="Train - 1 or Test - 0")
    parser.add_argument("--downsample", nargs='?', const=True, default=False, help="Downsampling GAN")
    args = parser.parse_args()
    pathG = args.G_path
    pathD = args.D_path
    srDir = args.dir_path
    test_batch_size = args.batch_size
    shuffle_dataset = True
    random_seed = 42
    root = srDir
    dataset = TDF(root, 25, 2, args.downsample, args.train)
    dataset_size = len(dataset)
    print(dataset_size)
    indices = list(range(dataset_size))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    indices_test = indices
    print(len(indices_test))
    # Creating PT data samplers and loaders:
    test_sampler = SubsetRandomSampler(indices_test)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size, sampler=test_sampler)
    # Num batches
    num_batches = len(test_loader)
    print(num_batches)
    G = Generator(args.in_channels, 2, args.res_blocks, args.downsample)
    G.load_state_dict(torch.load(pathG))
    G.eval()
    D = Discriminator(args.in_channels)
    D.load_state_dict(torch.load(pathD))
    D.eval()
    G.cuda()
    D.cuda()
    coords = ["0", "0", "0"]
    with torch.no_grad():
        for index, (lr, hr, filName) in enumerate(test_loader):
            lr = lr.float()
            val_z = Variable(lr)
            val_z = val_z.cuda()
            sr_test = G(val_z)
            val_target = Variable(hr)
            val_target = val_target.cuda()
            hr_test = D(val_target).mean()
            hr_fake = D(sr_test).mean()
            utils.write_voxels(args.batch_size, srDir, sr_test, index, args.downsample, "test", coords, filName)
            if (index) % 50 == 0:
                print(index)
                print(torch.cuda.memory_allocated())
def test(model_path, data=(hparams.valid_csv, hparams.dev_file), plot_auc='valid', plot_path=hparams.result_dir + 'valid', best_thresh=None): test_dataset = AudioData(data_csv=data[0], data_file=data[1], ds_type='valid', augment=True, transform=transforms.Compose([ transforms.ToTensor(), ])) test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size, shuffle=True, num_workers=2) discriminator = Discriminator().to(hparams.gpu_device) if hparams.cuda: discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids) checkpoint = torch.load(model_path, map_location=hparams.gpu_device) discriminator.load_state_dict(checkpoint['discriminator_state_dict']) discriminator = discriminator.eval() # print('Model loaded') Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor print('Testing model on {0} examples. '.format(len(test_dataset))) with torch.no_grad(): pred_logits_list = [] labels_list = [] img_names_list = [] # for _ in range(hparams.repeat_infer): for (inp, labels, img_names) in tqdm(test_loader): inp = Variable(inp.float(), requires_grad=False) labels = Variable(labels.long(), requires_grad=False) inp = inp.to(hparams.gpu_device) labels = labels.to(hparams.gpu_device) if hparams.dim3: inp = inp.view(-1, 1, 640, 64) inp = torch.cat([inp] * 3, dim=1) pred_logits = discriminator(inp) pred_logits_list.append(pred_logits) labels_list.append(labels) img_names_list.append(img_names) pred_logits = torch.cat(pred_logits_list, dim=0) labels = torch.cat(labels_list, dim=0) auc, f1, acc, conf_mat = accuracy_metrics(labels, pred_logits, plot_auc=plot_auc, plot_path=plot_path, best_thresh=best_thresh) fig = plot_cf(conf_mat) plt.savefig(hparams.result_dir + 'test_conf_mat.png') res = ' -- avg_acc - {0:.4f}'.format(acc['avg']) for it in range(10): res += ', acc_{}'.format( hparams.id_to_class[it]) + ' - {0:.4f}'.format(acc[it]) print('== Test on -- ' + model_path + res) # print('== Test on -- '+model_path+' == \n\ # auc_{0} - {10:.4f}, auc_{1} - {11:.4f}, auc_{2} - {12:.4f}, auc_{3} - {13:.4f}, auc_{4} - {14:.4f}, auc_{5} - {15:.4f}, auc_{6} - {16:.4f}, auc_{7} - {17:.4f}, auc_{8} - {18:.4f}, auc_{9} - {19:.4f}, auc_micro - {20:.4f}, auc_macro - {21:.4f},\n\ # acc_{0} - {22:.4f}, acc_{1} - {23:.4f}, acc_{2} - {24:.4f}, acc_{3} - {25:.4f}, acc_{4} - {26:.4f}, acc_{5} - {27:.4f}, acc_{6} - {28:.4f}, acc_{7} - {29:.4f}, acc_{8} - {30:.4f}, acc_{9} - {31:.4f}, acc_avg - {32:.4f},\n\ # f1_{0} - {33:.4f}, f1_{1} - {34:.4f}, f1_{2} - {35:.4f}, f1_{3} - {36:.4f}, f1_{4} - {37:.4f}, f1_{5} - {38:.4f}, f1_{6} - {39:.4f}, f1_{7} - {40:.4f}, f1_{8} - {41:.4f}, f1_{9} - {42:.4f}, f1_micro - {42:.4f}, f1_macro - {43:.4f}, =='.\ # format([hparams.id_to_class[it] for it in range(10)]+[auc[it] for it in range(10)]+[auc['micro'], auc['macro']]+[acc[it] for it in range(10)]+[acc['avg']]+[f1[it] for it in range(10)]+[f1['micro'], f1['macro']])) return acc['avg']
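# The hparams.dim3 branch in test() reshapes each mono spectrogram to
# (N, 1, 640, 64) and repeats the channel three times so an RGB-pretrained
# backbone can consume it. A standalone sketch of that transform (shapes
# follow the code; expand() is a copy-free alternative to torch.cat):
import torch

spec = torch.randn(8, 640, 64)             # a batch of 1-channel spectrograms
inp = spec.view(-1, 1, 640, 64)            # add a channel axis
inp3 = torch.cat([inp] * 3, dim=1)         # (8, 3, 640, 64), materializes copies
inp3_view = inp.expand(-1, 3, -1, -1)      # same shape, shares storage
assert inp3.shape == inp3_view.shape == (8, 3, 640, 64)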
'Train density Loss: {:.4f} Density Adversarial Loss: {:.4f} Discriminator Loss: {:.4f}'
    .format(loss_dens_value / iter_count, loss_adv_value / iter_count, loss_D_value / iter_count))
logger.scalar_summary('Temporal/train_density_loss', loss_dens_value, epoch)
logger.scalar_summary('Temporal/train_adv_loss', loss_adv_value, epoch)
logger.scalar_summary('Temporal/train_D_loss', loss_D_value, epoch)
# test MAE & MSE on the train set
epoch_mae = running_mae / totalnum
epoch_mse = np.sqrt(running_mse / totalnum)
print('Training Iteration:{} MAE: {:.4f} MSE: {:.4f}'.format(epoch, epoch_mae, epoch_mse))
# validation phase
net.eval()
net_D.eval()
running_loss = 0.0
running_mse = 0.0
running_mae = 0.0
totalnum = 0
for idx, (image, densityMap) in enumerate(val_loader):
    image = image.to(device)
    densityMap = densityMap.to(device)
    optimizer.zero_grad()
    duration = time.time()
    predDensityMap = net(image)
    outputs_np = predDensityMap.data.cpu().numpy()
    densityMap_np = densityMap.data.cpu().numpy()
class Solver(object): def __init__(self, args): # model self.g_optimizer = None self.d_optimizer = None self.generator = None self.discriminator = None self.MSELoss = None self.L1loss = None self.GPU_IN_USE = torch.cuda.is_available() self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu') # Training settings self.dataset = args.dataset self.num_epochs = args.num_epochs self.batch_size = args.batch_size self.threads = args.threads self.g_conv_dim = args.g_conv_dim self.d_conv_dim = args.d_conv_dim self.in_channel = args.in_channel self.out_channel = args.out_channel self.use_sigmoid = False # hyper-parameters self.lr = args.lr self.beta_1 = args.beta_1 self.lamb = args.lamb # dataloader self.training_data_loader = None self.testing_data_loader = None def build_model(self): self.generator = Generator(in_channel=self.in_channel, out_channel=self.out_channel, g_conv_dim=self.g_conv_dim, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9).to(self.device) self.generator.normal_init() self.discriminator = Discriminator( in_channel=self.in_channel + self.out_channel, d_conv_dim=self.d_conv_dim, num_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=self.use_sigmoid).to(self.device) self.discriminator.normal_init() self.MSELoss = nn.MSELoss() self.L1loss = nn.L1Loss() if self.GPU_IN_USE: self.MSELoss.cuda() self.L1loss.cuda() cudnn.benchmark = True self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta_1, 0.999)) self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta_1, 0.999)) def build_dataset(self): root_path = "datasets/" train_set = get_training_set(root_path + self.dataset) test_set = get_test_set(root_path + self.dataset) self.training_data_loader = DataLoader(dataset=train_set, num_workers=self.threads, batch_size=self.batch_size, shuffle=True) self.testing_data_loader = DataLoader(dataset=test_set, num_workers=self.threads, batch_size=self.batch_size, shuffle=False) @staticmethod def to_data(x): """Convert variable to tensor.""" if torch.cuda.is_available(): x = x.cpu() return x.data def reset_grad(self): """Zero the gradient buffers.""" self.d_optimizer.zero_grad() self.g_optimizer.zero_grad() @staticmethod def de_normalize(x): """Convert range (-1, 1) to (0, 1)""" out = (x + 1) / 2 return out.clamp(0, 1) def checkpoint(self, epoch): if not os.path.exists("checkpoint"): os.mkdir("checkpoint") if not os.path.exists(os.path.join("checkpoint", self.dataset)): os.mkdir(os.path.join("checkpoint", self.dataset)) net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format( self.dataset, epoch) net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format( self.dataset, epoch) torch.save(self.generator, net_g_model_out_path) torch.save(self.discriminator, net_d_model_out_path) print("Checkpoint saved to {}".format("checkpoint" + self.dataset)) def mode_switch(self, mode): if mode == 'train': self.discriminator.train() self.generator.train() elif mode == 'eval': self.discriminator.eval() self.generator.eval() def train(self): self.mode_switch('train') for i, (data, target) in enumerate(self.training_data_loader): # forward data, target = data.to(self.device), target.to(self.device) fake_target = self.generator(data) ########################### # (1) train D network: maximize log(D(x,y)) + log(1 - D(x,G(x))) ########################### self.reset_grad() # train with fake fake_combined = torch.cat((data, fake_target), 1) fake_prediction = self.discriminator(fake_combined.detach()) 
fake_d_loss = self.MSELoss(fake_prediction, torch.zeros_like(fake_prediction))
# train with real
real_combined = torch.cat((data, target), 1)
real_prediction = self.discriminator(real_combined)
real_d_loss = self.MSELoss(real_prediction, torch.ones_like(real_prediction))
# Combined loss
loss_d = (fake_d_loss + real_d_loss) * 0.5
loss_d.backward()
self.d_optimizer.step()
##########################
# (2) train G network: maximize log(D(x,G(x))) + L1(y,G(x))
##########################
self.reset_grad()
# First, G(A) should fake the discriminator
fake_combined = torch.cat((data, fake_target), 1)
fake_prediction = self.discriminator(fake_combined)
g_loss_mse = self.MSELoss(fake_prediction, torch.ones_like(fake_prediction))
# Second, G(A) = B
g_loss_l1 = self.L1loss(fake_target, target) * self.lamb
loss_g = g_loss_mse + g_loss_l1
loss_g.backward()
self.g_optimizer.step()
print("({}/{}): Loss_D: {:.4f} Loss_G: {:.4f}".format(i, len(self.training_data_loader), loss_d.item(), loss_g.item()))

    def test(self):
        self.mode_switch('eval')
        avg_psnr = 0
        with torch.no_grad():
            for (data, target) in self.testing_data_loader:
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.generator(data)
                mse = self.MSELoss(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
        print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_data_loader)))

    def run(self):
        self.build_model()
        self.build_dataset()
        for e in range(1, self.num_epochs + 1):
            print("===> Epoch {}/{}".format(e, self.num_epochs))
            self.train()
            self.checkpoint(e)
            self.test()
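# Solver.test() scores PSNR as 10 * log10(1 / MSE), i.e. the usual
# 10 * log10(MAX^2 / MSE) with the peak value MAX taken as 1.0 (outputs
# assumed scaled to [0, 1]). A standalone sketch of the same metric:
from math import log10
import torch

def psnr(pred, target, max_val=1.0):
    mse = torch.mean((pred - target) ** 2).item()
    return 10 * log10(max_val ** 2 / mse)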
class CycleBWE(object):
    def __init__(self, params):
        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang, params.tgt_lang, params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)
        self.tune_export_dir = "{}/export".format(self.tune_dir)
        if self.params.eval_file == 'wiki':
            self.eval_file = '../data/bilingual_dicts/{}-{}.5000-6500.txt'.format(self.params.src_lang, self.params.tgt_lang)
            self.eval_file2 = '../data/bilingual_dicts/{}-{}.5000-6500.txt'.format(self.params.tgt_lang, self.params.src_lang)
        elif self.params.eval_file == 'wacky':
            self.eval_file = '../data/bilingual_dicts/{}-{}.test.txt'.format(self.params.src_lang, self.params.tgt_lang)
            self.eval_file2 = '../data/bilingual_dicts/{}-{}.test.txt'.format(self.params.tgt_lang, self.params.src_lang)
        else:
            print('Invalid eval file!')
        # self.seed = random.randint(0, 1000)
        # self.seed = 41
        # self.initialize_exp(self.seed)
        self.X_AE = AE(params)
        self.Y_AE = AE(params)
        self.D_X = Discriminator(input_size=params.d_input_size, hidden_size=params.d_hidden_size, output_size=params.d_output_size)
        self.D_Y = Discriminator(input_size=params.d_input_size, hidden_size=params.d_hidden_size, output_size=params.d_output_size)
        self.nets = [self.X_AE, self.Y_AE, self.D_X, self.D_Y]
        self.loss_fn = torch.nn.BCELoss()
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):  # orthogonal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):  # xavier_normal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):  # identity-matrix initialization
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(torch.diag(torch.ones(self.params.g_input_size)))

    def init_state(self, state=1):
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()
        if self.params.init == 'eye':
            self.X_AE.apply(self.weights_init3)  # G init scheme can be changed
            self.Y_AE.apply(self.weights_init3)  # G init scheme can be changed
        elif self.params.init == 'orth':
            self.X_AE.apply(self.weights_init)  # G init scheme can be changed
            self.Y_AE.apply(self.weights_init)
        else:
            print('Invalid init func!')
        # self.D_X.apply(self.weights_init2)
        # self.D_Y.apply(self.weights_init2)

    def orthogonalize(self, W):
        params = self.params
        W.copy_((1 + params.beta) * W - params.beta * W.mm(W.transpose(0, 1).mm(W)))

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            print("Data path doesn't exist: %s" % params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)
        if not os.path.exists(self.tune_export_dir):
            os.makedirs(self.tune_export_dir)
        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]
        en = src_emb
        it = tgt_emb
        params = _get_eval_params(params)
        self.params = params
        eval = Evaluator(params, en, it, torch.cuda.is_available())
        # for seed_index in range(params.num_random_seeds):
        AE_optimizer = optim.SGD(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())), lr=params.g_learning_rate)
        # AE_optimizer = optim.SGD(G_params, lr=0.1, momentum=0.9)
        # AE_optimizer = optim.Adam(G_params, lr=params.g_learning_rate, betas=(0.9, 0.9))
        # AE_optimizer = optim.RMSprop(filter(lambda
p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),lr=params.g_learning_rate,alpha=0.9) D_optimizer = optim.SGD(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=params.d_learning_rate) # D_optimizer = optim.Adam(D_params, lr=params.d_learning_rate, betas=(0.5, 0.9)) # D_optimizer = optim.RMSprop(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=params.d_learning_rate , alpha=0.9) # D_X=nn.DataParallel(D_X) # D_Y=nn.DataParallel(D_Y) # true_dict = get_true_dict(params.data_dir) D_A_acc_epochs = [] D_B_acc_epochs = [] D_A_loss_epochs = [] D_B_loss_epochs = [] G_AB_loss_epochs = [] G_BA_loss_epochs = [] G_AB_recon_epochs = [] G_BA_recon_epochs = [] L_Z_loss_epoches = [] acc1_epochs = [] acc2_epochs = [] csls_epochs = [] f_csls_epochs = [] b_csls_epochs = [] best_valid_metric = -100 # logs for plotting later log_file = open( "log_src_tgt.txt", "w") # Being overwritten in every loop, not really required log_file.write("epoch, dis_loss, dis_acc, g_loss\n") try: for epoch in range(self.params.num_epochs): D_A_losses = [] D_B_losses = [] G_AB_losses = [] G_AB_recon = [] G_BA_losses = [] G_adv_losses = [] G_BA_recon = [] L_Z_losses = [] d_losses = [] g_losses = [] hit_A = 0 hit_B = 0 total = 0 start_time = timer() # lowest_loss = 1e5 # label_D = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_()) label_D = to_variable( torch.FloatTensor(2 * params.mini_batch_size).zero_()) label_D[:params.mini_batch_size] = 1 - params.smoothing label_D[params.mini_batch_size:] = params.smoothing label_G = to_variable( torch.FloatTensor(params.mini_batch_size).zero_()) label_G = label_G + 1 - params.smoothing for mini_batch in range( 0, params.iters_in_epoch // params.mini_batch_size): for d_index in range(params.d_steps): D_optimizer.zero_grad() # Reset the gradients self.D_X.train() self.D_Y.train() #print('D_X:', self.D_X.map1.weight.data) #print('D_Y:', self.D_Y.map1.weight.data) view_X, view_Y = self.get_batch_data_fast_new(en, it) # Discriminator X #print('View_Y',view_Y) fake_X = self.Y_AE.encode(view_Y).detach() #print('fakeX',fake_X) input = torch.cat([view_X, fake_X], 0) pred_A = self.D_X(input) #print('Pred_A',pred_A) D_A_loss = self.loss_fn(pred_A, label_D) # print(view_Y) # Discriminator Y # print('View_X',view_X) fake_Y = self.X_AE.encode(view_X).detach() # print('fakeY:',fake_Y) input = torch.cat([view_Y, fake_Y], 0) pred_B = self.D_Y(input) # print('Pred_B', pred_B) D_B_loss = self.loss_fn(pred_B, label_D) D_loss = (1.0) * D_A_loss + params.gate * D_B_loss D_loss.backward( ) # compute/store gradients, but don't change params d_losses.append(to_numpy(D_loss.data)) D_A_losses.append(to_numpy(D_A_loss.data)) D_B_losses.append(to_numpy(D_B_loss.data)) discriminator_decision_A = to_numpy(pred_A.data) hit_A += np.sum( discriminator_decision_A[:params.mini_batch_size] >= 0.5) hit_A += np.sum( discriminator_decision_A[params.mini_batch_size:] < 0.5) discriminator_decision_B = to_numpy(pred_B.data) hit_B += np.sum( discriminator_decision_B[:params.mini_batch_size] >= 0.5) hit_B += np.sum( discriminator_decision_B[params.mini_batch_size:] < 0.5) D_optimizer.step( ) # Only optimizes D's parameters; changes based on stored gradients from backward() # Clip weights _clip(self.D_X, params.clip_value) _clip(self.D_Y, params.clip_value) # print('D_loss',d_losses) sys.stdout.write( "[%d/%d] :: Discriminator Loss: %.3f \r" % (mini_batch, params.iters_in_epoch // params.mini_batch_size, np.asscalar(np.mean(d_losses)))) sys.stdout.flush() total += 2 * 
params.mini_batch_size * params.d_steps for g_index in range(params.g_steps): # 2. Train G on D's response (but DO NOT train D on these labels) AE_optimizer.zero_grad() self.D_X.eval() self.D_Y.eval() view_X, view_Y = self.get_batch_data_fast_new(en, it) # Generator X_AE ## adversarial loss Y_fake = self.X_AE.encode(view_X) # X_recon = self.X_AE.decode(X_Z) # Y_fake = self.Y_AE.encode(X_Z) pred_Y = self.D_Y(Y_fake) L_adv_X = self.loss_fn(pred_Y, label_G) X_Cycle = self.Y_AE.encode(Y_fake) L_Cycle_X = 1.0 - torch.mean( self.loss_fn2(view_X, X_Cycle)) # L_recon_X = 1.0 - torch.mean(self.loss_fn2(view_X, X_recon)) # L_G_AB = L_adv_X + params.recon_weight * L_recon_X # Generator Y_AE # adversarial loss X_fake = self.Y_AE.encode(view_Y) pred_X = self.D_X(X_fake) L_adv_Y = self.loss_fn(pred_X, label_G) ### Cycle Loss Y_Cycle = self.X_AE.encode(X_fake) L_Cycle_Y = 1.0 - torch.mean( self.loss_fn2(view_Y, Y_Cycle)) # L_recon_Y = 1.0 - torch.mean(self.loss_fn2(view_Y, Y_recon)) # L_G_BA = L_adv_Y + params.recon_weight * L_recon_Y # L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z)) # G_loss = L_G_AB + L_G_BA + L_Z G_loss = params.adv_weight * ( params.gate * L_adv_X + (1.0) * L_adv_Y) + \ params.cycle_weight * (L_Cycle_X+L_Cycle_Y) G_loss.backward() g_losses.append(to_numpy(G_loss.data)) G_AB_losses.append(to_numpy(L_adv_X.data)) G_BA_losses.append(to_numpy(L_adv_Y.data)) G_adv_losses.append(to_numpy(L_adv_Y.data)) G_AB_recon.append(to_numpy(L_Cycle_X.data)) G_BA_recon.append(to_numpy(L_Cycle_Y.data)) AE_optimizer.step() # Only optimizes G's parameters self.orthogonalize(self.X_AE.map1.weight.data) self.orthogonalize(self.Y_AE.map1.weight.data) sys.stdout.write( "[%d/%d] :: Generator Loss: %.3f \r" % (mini_batch, params.iters_in_epoch // params.mini_batch_size, np.asscalar(np.mean(g_losses)))) sys.stdout.flush() '''for each epoch''' D_A_acc_epochs.append(hit_A / total) D_B_acc_epochs.append(hit_B / total) G_AB_loss_epochs.append(np.asscalar(np.mean(G_AB_losses))) G_BA_loss_epochs.append(np.asscalar(np.mean(G_BA_losses))) D_A_loss_epochs.append(np.asscalar(np.mean(D_A_losses))) D_B_loss_epochs.append(np.asscalar(np.mean(D_B_losses))) G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon))) G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon))) # L_Z_loss_epoches.append(np.asscalar(np.mean(L_Z_losses))) print( "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins" .format(epoch, np.asscalar(np.mean(d_losses)), 0.5 * (hit_A + hit_B) / total, np.asscalar(np.mean(g_losses)), (timer() - start_time) / 60)) # lr decay # g_optim_state = AE_optimizer.state_dict() # old_lr = g_optim_state['param_groups'][0]['lr'] # g_optim_state['param_groups'][0]['lr'] = max(old_lr * params.lr_decay, params.lr_min) # AE_optimizer.load_state_dict(g_optim_state) # print("Changing the learning rate: {} -> {}".format(old_lr, g_optim_state['param_groups'][0]['lr'])) # d_optim_state = D_optimizer.state_dict() # d_optim_state['param_groups'][0]['lr'] = max( # d_optim_state['param_groups'][0]['lr'] * params.lr_decay, params.lr_min) # D_optimizer.load_state_dict(d_optim_state) # d_optim_state['param_groups'][0]['lr'] * params.lr_decay, params.lr_min) # D_optimizer.load_state_dict(d_optim_state) if (epoch + 1) % params.print_every == 0: # No need for discriminator weights # torch.save(d.state_dict(), 'discriminator_weights_en_es_{}.t7'.format(epoch)) # all_precisions = eval.get_all_precisions(G_AB(src_emb.weight).data) Vec_xy = self.X_AE.encode(Variable(en)) Vec_xyx 
= self.Y_AE.encode(Vec_xy) Vec_yx = self.Y_AE.encode(Variable(it)) Vec_yxy = self.X_AE.encode(Vec_yx) mstart_time = timer() # for method in ['csls_knn_10']: for method in [params.eval_method]: results = get_word_translation_accuracy( params.src_lang, src_word2id, Vec_xy.data, params.tgt_lang, tgt_word2id, it, method=method, dico_eval=self.eval_file, device=params.cuda_device) acc1 = results[0][1] results = get_word_translation_accuracy( params.tgt_lang, tgt_word2id, Vec_yx.data, params.src_lang, src_word2id, en, method=method, dico_eval=self.eval_file2, device=params.cuda_device) acc2 = results[0][1] print('{} takes {:.2f}s'.format( method, timer() - mstart_time)) print('Method:{} test_score:{:.4f}-{:.4f}'.format( method, acc1, acc2)) ''' # for method in ['csls_knn_10']: for method in [params.eval_method]: results = get_word_translation_accuracy( params.src_lang, src_word2id, Vec_xyx.data, params.src_lang, src_word2id, en, method=method, dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.src_lang,params.src_lang), device=params.cuda_device ) acc11 = results[0][1] # for method in ['csls_knn_10']: for method in [params.eval_method]: results = get_word_translation_accuracy( params.tgt_lang, tgt_word2id, Vec_yxy.data, params.tgt_lang, tgt_word2id, it, method=method, dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.tgt_lang,params.tgt_lang), device=params.cuda_device ) acc22 = results[0][1] print('Valid:{} score:{:.4f}-{:.4f}'.format(method, acc11, acc22)) avg_valid = (acc11+acc22)/2.0 # valid_x = torch.mean(self.loss_fn2(en, Vec_xyx.data)) # valid_y = torch.mean(self.loss_fn2(it, Vec_yxy.data)) # avg_valid = (valid_x+valid_y)/2.0 ''' # csls = 0 f_csls = eval.dist_mean_cosine(Vec_xy.data, it) b_csls = eval.dist_mean_cosine(Vec_yx.data, en) csls = (f_csls + b_csls) / 2.0 # csls = eval.calc_unsupervised_criterion(X_Z) if csls > best_valid_metric: print("New csls value: {}".format(csls)) best_valid_metric = csls fp = open( self.tune_dir + "/best/seed_{}_dico_{}_epoch_{}_acc_{:.3f}-{:.3f}.tmp" .format(seed, params.dico_build, epoch, acc1, acc2), 'w') fp.close() torch.save( self.X_AE.state_dict(), self.tune_dir + '/best/seed_{}_dico_{}_best_X.t7'.format( seed, params.dico_build)) torch.save( self.Y_AE.state_dict(), self.tune_dir + '/best/seed_{}_dico_{}_best_Y.t7'.format( seed, params.dico_build)) torch.save( self.D_X.state_dict(), self.tune_dir + '/best/seed_{}_dico_{}_best_Dx.t7'.format( seed, params.dico_build)) torch.save( self.D_Y.state_dict(), self.tune_dir + '/best/seed_{}_dico_{}_best_Dy.t7'.format( seed, params.dico_build)) # print(json.dumps(all_precisions)) # p_1 = all_precisions['validation']['adv']['without-ref']['nn'][1] # p_1 = all_precisions['validation']['adv']['without-ref']['csls'][1] # log_file.write(str(results) + "\n") # print('Method: nn score:{:.4f}'.format(acc)) # Saving generator weights # torch.save(X_AE.state_dict(), tune_dir+'/G_AB_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc)) # torch.save(Y_AE.state_dict(), tune_dir+'/G_BA_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc)) fp = open( self.tune_dir + "/seed_{}_epoch_{}_acc_{:.3f}-{:.3f}_valid_{:.4f}.tmp". 
format(seed, epoch, acc1, acc2, csls), 'w') fp.close() acc1_epochs.append(acc1) acc2_epochs.append(acc2) csls_epochs.append(csls) f_csls_epochs.append(f_csls) b_csls_epochs.append(b_csls) csls_fb, epoch_fb = max([ (score, index) for index, score in enumerate(csls_epochs) ]) fp = open( self.tune_dir + "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsfb".format( seed, epoch_fb, acc1_epochs[epoch_fb], acc2_epochs[epoch_fb], csls_fb), 'w') fp.close() csls_f, epoch_f = max([ (score, index) for index, score in enumerate(f_csls_epochs) ]) fp = open( self.tune_dir + "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsf".format( seed, epoch_f, acc1_epochs[epoch_f], acc2_epochs[epoch_f], csls_f), 'w') fp.close() csls_b, epoch_b = max([ (score, index) for index, score in enumerate(b_csls_epochs) ]) fp = open( self.tune_dir + "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsb".format( seed, epoch_b, acc1_epochs[epoch_b], acc2_epochs[epoch_b], csls_b), 'w') fp.close() ''' # Save the plot for discriminator accuracy and generator loss fig = plt.figure() plt.plot(range(0, len(D_A_acc_epochs)), D_A_acc_epochs, color='b', label='D_A') plt.plot(range(0, len(D_B_acc_epochs)), D_B_acc_epochs, color='r', label='D_B') plt.ylabel('D_accuracy') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(D_A_loss_epochs)), D_A_loss_epochs, color='b', label='D_A') plt.plot(range(0, len(D_B_loss_epochs)), D_B_loss_epochs, color='r', label='D_B') plt.ylabel('D_losses') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(G_AB_loss_epochs)), G_AB_loss_epochs, color='b', label='G_AB') plt.plot(range(0, len(G_BA_loss_epochs)), G_BA_loss_epochs, color='r', label='G_BA') plt.ylabel('G_losses') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB') plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA') plt.ylabel('G_Cycle_loss') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_G_Cycle.png'.format(seed)) # fig = plt.figure() # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z') # plt.ylabel('L_Z_loss') # plt.xlabel('epochs') # plt.legend() # fig.savefig(tune_dir + '/seed_{}_stage_{}_L_Z.png'.format(seed,stage)) fig = plt.figure() plt.plot(range(0, len(acc1_epochs)), acc1_epochs, color='b', label='trans_acc1') plt.plot(range(0, len(acc2_epochs)), acc2_epochs, color='r', label='trans_acc2') plt.ylabel('trans_acc') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls') plt.plot(range(0, len(f_csls_epochs)), f_csls_epochs, color='r', label='csls_f') plt.plot(range(0, len(b_csls_epochs)), b_csls_epochs, color='g', label='csls_b') plt.ylabel('csls') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(g_losses)), g_losses, color='b', label='G_loss') plt.ylabel('g_loss') plt.xlabel('epochs') plt.legend() fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed)) fig = plt.figure() plt.plot(range(0, len(d_losses)), d_losses, color='b', label='csls') plt.ylabel('D_loss') plt.xlabel('epochs') plt.legend() 
fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed)) plt.close('all') ''' except KeyboardInterrupt: print("Interrupted.. saving model !!!") torch.save(self.X_AE.state_dict(), 'g_model_interrupt.t7') torch.save(self.D_X.state_dict(), 'd_model_interrupt.t7') log_file.close() exit() log_file.close() return self.X_AE def get_batch_data_fast_new(self, emb_en, emb_it): params = self.params random_en_indices = torch.LongTensor(params.mini_batch_size).random_( params.most_frequent_sampling_size) random_it_indices = torch.LongTensor(params.mini_batch_size).random_( params.most_frequent_sampling_size) #print(random_en_indices) #print(random_it_indices) en_batch = to_variable(emb_en)[random_en_indices.cuda()] it_batch = to_variable(emb_it)[random_it_indices.cuda()] return en_batch, it_batch def export(self, src_dico, tgt_dico, emb_en, emb_it, seed, export_emb=False): params = _get_eval_params(self.params) eval = Evaluator(params, emb_en, emb_it, torch.cuda.is_available()) # Export adversarial dictionaries optim_X_AE = AE(params).cuda() optim_Y_AE = AE(params).cuda() print('Loading pre-trained models...') optim_X_AE.load_state_dict( torch.load(self.tune_dir + '/best/seed_{}_dico_{}_best_X.t7'.format( seed, params.dico_build))) optim_Y_AE.load_state_dict( torch.load(self.tune_dir + '/best/seed_{}_dico_{}_best_Y.t7'.format( seed, params.dico_build))) X_Z = optim_X_AE.encode(Variable(emb_en)).data Y_Z = optim_Y_AE.encode(Variable(emb_it)).data mstart_time = timer() for method in ['nn', 'csls_knn_10']: results = get_word_translation_accuracy(params.src_lang, src_dico[1], X_Z, params.tgt_lang, tgt_dico[1], emb_it, method=method, dico_eval=self.eval_file, device=params.cuda_device) acc1 = results[0][1] results = get_word_translation_accuracy(params.tgt_lang, tgt_dico[1], Y_Z, params.src_lang, src_dico[1], emb_en, method=method, dico_eval=self.eval_file2, device=params.cuda_device) acc2 = results[0][1] # csls = 0 print('{} takes {:.2f}s'.format(method, timer() - mstart_time)) print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2)) f_csls = eval.dist_mean_cosine(X_Z, emb_it) b_csls = eval.dist_mean_cosine(Y_Z, emb_en) csls = (f_csls + b_csls) / 2.0 print("Seed:{},ACC:{:.4f}-{:.4f},CSLS_FB:{:.6f}".format( seed, acc1, acc2, csls)) #''' print('Building dictionaries...') params.dico_build = "S2T&T2S" params.dico_method = "csls_knn_10" X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z) emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it) f_dico_induce = build_dictionary(X_Z, emb_it, params) f_dico_induce = f_dico_induce.cpu().numpy() Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z) emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en) b_dico_induce = build_dictionary(Y_Z, emb_en, params) b_dico_induce = b_dico_induce.cpu().numpy() f_dico_set = set([(a, b) for a, b in f_dico_induce]) b_dico_set = set([(b, a) for a, b in b_dico_induce]) intersect = list(f_dico_set & b_dico_set) union = list(f_dico_set | b_dico_set) with io.open( self.tune_dir + '/export/{}-{}.dict'.format(params.src_lang, params.tgt_lang), 'w', encoding='utf-8', newline='\n') as f: for item in f_dico_induce: f.write('{} {}\n'.format(src_dico[0][item[0]], tgt_dico[0][item[1]])) with io.open( self.tune_dir + '/export/{}-{}.dict'.format(params.tgt_lang, params.src_lang), 'w', encoding='utf-8', newline='\n') as f: for item in b_dico_induce: f.write('{} {}\n'.format(tgt_dico[0][item[0]], src_dico[0][item[1]])) with io.open(self.tune_dir + '/export/{}-{}.intersect'.format( params.src_lang, 
params.tgt_lang), 'w', encoding='utf-8', newline='\n') as f: for item in intersect: f.write('{} {}\n'.format(src_dico[0][item[0]], tgt_dico[0][item[1]])) with io.open(self.tune_dir + '/export/{}-{}.intersect'.format( params.tgt_lang, params.src_lang), 'w', encoding='utf-8', newline='\n') as f: for item in intersect: f.write('{} {}\n'.format(tgt_dico[0][item[1]], src_dico[0][item[0]])) with io.open( self.tune_dir + '/export/{}-{}.union'.format(params.src_lang, params.tgt_lang), 'w', encoding='utf-8', newline='\n') as f: for item in union: f.write('{} {}\n'.format(src_dico[0][item[0]], tgt_dico[0][item[1]])) with io.open( self.tune_dir + '/export/{}-{}.union'.format(params.tgt_lang, params.src_lang), 'w', encoding='utf-8', newline='\n') as f: for item in union: f.write('{} {}\n'.format(tgt_dico[0][item[1]], src_dico[0][item[0]])) if export_emb: print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang, params.src_lang)) loader.export_embeddings( src_dico[0], X_Z, path=self.tune_dir + '/export/{}-{}.{}'.format( params.src_lang, params.tgt_lang, params.src_lang), eformat='txt') print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang, params.tgt_lang)) loader.export_embeddings( tgt_dico[0], emb_it, path=self.tune_dir + '/export/{}-{}.{}'.format( params.src_lang, params.tgt_lang, params.tgt_lang), eformat='txt') print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang, params.tgt_lang)) loader.export_embeddings( tgt_dico[0], Y_Z, path=self.tune_dir + '/export/{}-{}.{}'.format( params.tgt_lang, params.src_lang, params.tgt_lang), eformat='txt') print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang, params.src_lang)) loader.export_embeddings( src_dico[0], emb_en, path=self.tune_dir + '/export/{}-{}.{}'.format( params.tgt_lang, params.src_lang, params.src_lang), eformat='txt')
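# CycleBWE.orthogonalize applies the iterative update
#     W <- (1 + beta) * W - beta * W (W^T W)
# which pulls the mapping toward the orthogonal manifold (the same update
# used in the MUSE codebase). A quick standalone check with an illustrative
# beta; the training code uses params.beta instead:
import torch

beta, W = 0.1, torch.randn(100, 100) * 0.1
for _ in range(200):
    W = (1 + beta) * W - beta * W.mm(W.t().mm(W))
print(torch.dist(W.t().mm(W), torch.eye(100)))  # small residual: W is nearly orthogonal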
model = Discriminator(args.length, len(species)).to(device) optimizer = optim.Adam(model.parameters(), lr=args.rate) # raise an error if receptive field is smaller than sampling length if args.length < receptive_field(model): raise Exception("Input sequences must be longer than {} bp.".format( receptive_field(model))) for epoch in range(args.epoch): train(model, device, loader, optimizer, epoch + 1) print("") # calculate style matrices if args.verbose > 1: print("Extracting style matrices...") model.eval() style_matrices = [] for record in SeqIO.parse(args.contig, "fasta"): tensor = to_tensor(str(record.seq)) style_matrices += model.get_style(tensor.float().to(device), args.layer) style_matrices = torch.cat(style_matrices, dim=0) torch.save(style_matrices, args.output) if args.verbose > 1: print("Genome style matrix is successfully written to {}.".format( args.output))
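# model.get_style() is not defined in this excerpt; in style-transfer
# terminology a "style matrix" is usually the Gram matrix of a layer's
# feature maps, so the sketch below is an assumption about what it computes,
# written for the 1-D convolutional features a sequence model would produce:
import torch

def gram_matrix(features):
    # features: (batch, channels, length) activations from a conv layer
    b, c, l = features.size()
    gram = torch.bmm(features, features.transpose(1, 2))  # (batch, c, c)
    return gram / (c * l)  # normalize by layer size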
def train():
    args = Config
    generator = Generator(args)
    discriminator = Discriminator(args)
    feat_extractor = get_feat_extractor()
    dataset = SRDataset(dataset_path=args['train_set_path'], hr_size=args['hr_size'], scale_factor=args['scale'])
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=Config['batch_size'], shuffle=True)
    test_dataset = SRDataset(args['test_set_path'], args['hr_size'], args['scale'])
    test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True)
    if Config['optimizer'] == 'Adam':
        gen_optimizer = torch.optim.Adam(generator.parameters(), lr=Config['lr'])
        disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=Config['lr'])
    else:
        gen_optimizer = torch.optim.SGD(generator.parameters(), lr=Config['lr'])
        disc_optimizer = torch.optim.SGD(discriminator.parameters(), lr=Config['lr'])
    if Config['tensorboard_log']:
        writer = SummaryWriter(Config['checkpoint_path'])
    for epoch in tqdm(range(Config['epochs'])):
        generator.train()
        discriminator.train()
        for lr, hr in data_loader:
            # adversarial targets: ones for real images, zeros for generated ones
            valid = torch.ones((lr.shape[0], 1), requires_grad=False)
            fake = torch.zeros((lr.shape[0], 1), requires_grad=False)
            # print(lr.shape)
            sr = generator(lr)
            # ---- generator update (BCELoss takes (prediction, target)) ----
            gen_optimizer.zero_grad()
            d_fake = discriminator(sr)
            c_loss = content_loss(args, feat_extractor, hr, sr)
            adv_loss = 1e-3 * nn.BCELoss()(d_fake, valid)
            mse_loss = nn.MSELoss()(sr, hr)
            perceptual_loss = c_loss + adv_loss + mse_loss
            perceptual_loss.backward()
            gen_optimizer.step()
            # ---- discriminator update (detach sr so G's graph is not reused) ----
            disc_optimizer.zero_grad()
            d_real = discriminator(hr)
            d_fake = discriminator(sr.detach())
            valid_loss = nn.BCELoss()(d_real, valid)
            fake_loss = nn.BCELoss()(d_fake, fake)
            d_loss = valid_loss + fake_loss
            d_loss.backward()
            disc_optimizer.step()
        generator.eval()
        discriminator.eval()
        test_lr, test_hr = next(iter(test_data_loader))
        with torch.set_grad_enabled(False):
            test_sr = generator(test_lr)  # run on the test batch, not the last training sr
        for i in range(test_sr.shape[0]):
            img_sr = test_sr[i]
            img_hr = test_hr[i]
            img_lr = test_lr[i]
            save_image(img_sr, 'img_sr_%d.png' % i)
            save_image(img_hr, 'img_hr_%d.png' % i)
            save_image(img_lr, 'img_lr_%d.png' % i)
        print(f'Epoch {epoch}: Perceptual Loss:{perceptual_loss:.4f}, Disc Loss:{d_loss:.4f}')
        torch.save({'generator': generator, 'discriminator': discriminator}, os.path.join(Config['checkpoint_path'], 'model.pth'))
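# train() above pickles the entire generator/discriminator modules, which
# only load back if the Generator/Discriminator classes are importable under
# the same module paths. Saving state_dicts is the more portable convention;
# a sketch of the equivalent checkpoint (same file name, different payload):
torch.save({
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict(),
}, os.path.join(Config['checkpoint_path'], 'model.pth'))
# restore later with:
#   generator.load_state_dict(torch.load(path)['generator'])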
# ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ # G_loss = criterion(fake_predict, torch.ones_like(fake_predict)) # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ # gen.zero_grad() G_loss.backward() gen_optim.step() # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ # if batch_idx == 0: print(f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \ Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}") with torch.no_grad(): disc.eval() gen.eval() fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W) data = real.reshape(-1, CHANNELS, H, W) if BATCH_SIZE > 32: fake = fake[:32] data = data[:32] img_grid_fake = torchvision.utils.make_grid(fake, normalize=True) img_grid_real = torchvision.utils.make_grid(data, normalize=True) writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step) writer_real.add_image("Mnist Real Images",
class StarGAN(nn.Module): def __init__(self, config, train_loader, test_loader): super(StarGAN, self).__init__() self.config = config self.train_loader = train_loader self.test_loader = test_loader self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.test_source, self.test_domain, _ = next(iter(self.test_loader)) self.test_source = self.test_source.to(self.device) self.test_domain = self.test_domain.view(-1, 1, 1).to(self.device) self.test_batch_size, _, self.height, self.width = self.test_source.size( ) self.save_img_cnt = 0 self.loss = {} self.items = {} self.iter_size = len(self.train_loader) self.epoch_size = config['max_iter'] // self.iter_size + 1 lr = config['lr'] lr_F = config['lr_F'] beta1 = config['beta1'] beta2 = config['beta2'] init = config['init'] # weight_decay = config['weight_decay'] self.batch_size = config['batch_size'] self.gan_type = config['gan_type'] self.max_iter = config['max_iter'] self.img_size = config['crop_size'] self.path_sample = os.path.join('./results/', config['save_name'], "samples") self.path_model = os.path.join('./results/', config['save_name'], "models") self.w_style = config['w_style'] self.w_ds = config['w_ds'] self.w_cyc = config['w_cyc'] self.w_regul = config['w_regul'] self.num_domain = len(train_loader.dataset.domains) self.dim_style = config['dim_style'] self.dim_latent = config['mapping_network']['dim_latent'] self.generator = Generator(config['gen']) # 29072960 # self.generator = DummyModel(config['gen']) # 29072960 self.style_encoder = StyleEncoder(config['style_encoder'], self.num_domain, self.img_size) self.mapping_network = MappingNetwork(config['mapping_network'], self.num_domain, self.dim_style) self.discriminator = Discriminator(config['dis'], self.num_domain, self.img_size) self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr, (beta1, beta2)) params_g = list(self.generator.parameters()) + list( self.style_encoder.parameters()) self.optimizer_g = torch.optim.Adam(params_g, lr, (beta1, beta2)) self.optimizer_g.add_param_group({ 'params': self.mapping_network.parameters(), 'lr': lr_F, 'betas': (beta1, beta2), }) # self.scheduler_g = get_scheduler(self.optimizer_g, config) # self.scheduler_d = get_scheduler(self.optimizer_d, config) self.apply(weights_init(init)) self.criterion_l1 = nn.L1Loss() self.criterion_l2 = nn.MSELoss() self.criterion_bce = nn.BCEWithLogitsLoss() self.to(self.device) # def update_scheduler(self): # if self.current_epoch >= 10 and self.scheduler_d and self.scheduler_g: # self.scheduler_d.step() # self.scheduler_g.step() def calc_adversarial_loss(self, logit, is_real): if self.gan_type == 'bce': target_fn = torch.ones_like if is_real else torch.zeros_like loss = self.criterion_bce(logit, target_fn(logit)) elif self.gan_type == 'lsgan': target_fn = torch.ones_like if is_real else torch.zeros_like loss = self.criterion_l2(logit, target_fn(logit)) elif self.gan_type == 'wgan': if is_real: loss = -torch.mean(logit) else: loss = torch.mean(logit) else: raise NotImplementedError("Unsupported gan type: {}".format( self.gan_type)) return loss def calc_r1(self, real_images, logit_real): grad_real = autograd.grad(outputs=logit_real.sum(), inputs=real_images, create_graph=True)[0] grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1)**2).mean() grad_penalty = 0.5 * grad_penalty return grad_penalty def calc_gp(self, real_images, fake_images): # TODO : raise NotImplementedError("") alpha = torch.rand(real_images.size(0), 1, 1, 1).to(self.device) interpolated = (alpha 
* real_images + ((1 - alpha) * fake_images)).requires_grad_(True)
        prob_interpolated, _ = self.discriminator(interpolated)
        grad_outputs = torch.ones(prob_interpolated.size()).to(self.device)
        gradients = torch.autograd.grad(outputs=prob_interpolated,
                                        inputs=interpolated,
                                        grad_outputs=grad_outputs,
                                        create_graph=True,
                                        retain_graph=True)[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    def generate_random_noise(self):
        random_noise = torch.randn(1, self.dim_latent).to(self.device)
        random_domain = torch.randint(self.num_domain,
                                      (self.batch_size, 1, 1)).to(self.device)
        return random_noise, random_domain

    def eval_mode_all(self):
        self.discriminator.eval()
        self.generator.eval()

    def update_d(self, real, real_domain, random_noise, random_domain):
        reset_gradients([self.optimizer_g, self.optimizer_d])
        real.requires_grad = True
        style_mapped = self.mapping_network(random_noise, random_domain)
        fake = self.generator(real, style_mapped)
        # Adv
        logit_real = self.discriminator(real, real_domain)
        logit_fake = self.discriminator(fake.detach(), random_domain)
        adv_d_real = self.calc_adversarial_loss(logit_real, is_real=True)
        adv_d_fake = self.calc_adversarial_loss(logit_fake, is_real=False)
        if self.config['gan_type'] == 'bce':
            regul = self.calc_r1(real, logit_real) * self.w_regul
        elif self.config['gan_type'] == 'wgan':
            regul = self.calc_gp(real, fake) * self.w_regul
        else:
            # no gradient regularizer for other gan types (e.g. 'lsgan');
            # added so `regul` is always defined below
            regul = torch.zeros(1).to(self.device)
        self.adv_d_fake = adv_d_fake
        self.adv_d_real = adv_d_real
        loss_d = adv_d_fake + adv_d_real + regul
        loss_d.backward()
        self.optimizer_d.step()
        self.loss['adv_d_fake'] = adv_d_fake.item()
        self.loss['adv_d_real'] = adv_d_real.item()
        self.loss['regul'] = regul.item()
        self.items["logit_real"] = logit_real
        self.items["logit_fake_d"] = logit_fake

    def update_g(self, real, real_domain, random_noise, random_domain):
        reset_gradients([self.optimizer_g, self.optimizer_d])
        style_fake = self.mapping_network(random_noise, random_domain)
        style_real = self.style_encoder(real, real_domain)
        fake = self.generator(real, style_fake)
        style_recon = self.style_encoder(fake, random_domain)
        image_recon = self.generator(fake, style_real)
        # Adversarial
        logit_fake = self.discriminator(fake, random_domain)
        adv_g = self.calc_adversarial_loss(logit_fake, is_real=True)
        # Style recon
        style_recon_loss = self.criterion_l1(style_fake, style_recon) * self.w_style
        # Style diversification
        random_noise1 = torch.randn(1, self.dim_latent).to(self.device)
        random_noise2 = torch.randn(1, self.dim_latent).to(self.device)
        random_domain1 = torch.randint(self.num_domain,
                                       (self.batch_size, 1, 1)).to(self.device)
        s1 = self.mapping_network(random_noise1, random_domain1)
        s2 = self.mapping_network(random_noise2, random_domain1)
        fake1 = self.generator(real, s1)
        fake2 = self.generator(real, s2)
        ds_loss = -self.criterion_l1(fake1, fake2) * self.w_ds
        # Cycle consistency
        cyc_loss = self.criterion_l1(real, image_recon) * self.w_cyc
        loss_g = adv_g + cyc_loss + style_recon_loss + ds_loss
        loss_g.backward()
        self.optimizer_g.step()
        self.loss['adv_g'] = adv_g.item()
        self.loss['style_recon_loss'] = style_recon_loss.item()
        self.loss['ds_loss'] = ds_loss.item()
        self.loss['cyc_loss'] = cyc_loss.item()
        self.items["real"] = real
        self.items["real_domain"] = real_domain
        self.items["random_noise"] = random_noise
        self.items["random_domain"] = random_domain
        self.items["random_noise1"] = random_noise1
        self.items["random_noise2"] = random_noise2
        self.items["random_domain1"] = random_domain1
        self.items["logit_fake"] = logit_fake
        self.items["style_fake"] = style_fake
        self.items["style_real"] = style_real
        self.items["fake"] = fake
        self.items["recon"] = image_recon
        self.items["style_recon"] = style_recon

    def train_starGAN(self, init_epoch):
        d_step, g_step = self.config['d_step'], self.config['g_step']
        log_iter = self.config['log_iter']
        image_display_iter = self.config['image_display_iter']
        image_save_iter = self.config['image_save_iter']
        for epoch in range(init_epoch, self.epoch_size):
            self.current_epoch = epoch
            self.save_img_cnt = 0
            for iters, (real, real_domain, _) in enumerate(self.train_loader):
                # self.update_scheduler()
                real, real_domain = real.to(self.device), real_domain.to(self.device)
                random_noise, random_domain = self.generate_random_noise()
                if not iters % d_step:  # was `iters & d_step`: bitwise AND typo for modulo
                    self.update_d(real, real_domain, random_noise, random_domain)
                if not iters % g_step:
                    self.update_g(real, real_domain, random_noise, random_domain)
                if self.device.type == 'cuda':
                    torch.cuda.synchronize()
                if not (iters + 1) % log_iter:
                    self.print_log(epoch, iters)
                if not (iters + 1) % image_display_iter:
                    show_batch_torch(torch.cat([
                        real,
                        self.items['fake'].clamp(-1, 1),
                        self.items['recon'].clamp(-1, 1)
                    ]), n_rows=3, n_cols=-1)
                if not (iters + 1) % image_save_iter:
                    self.test_sample = self.generate_test_samples(save=True)
                    clear_jupyter_console()
                # TODO : arbitrary decay schedule
                if epoch >= 10 and not (iters + 1) % 1000:
                    print("w_ds decayed:", self.w_ds, " -> ", self.w_ds * 0.9)
                    self.w_ds *= 0.9
            # self.save_models(epoch)

    def print_log(self, epoch, iters):
        adv_d_real = self.loss['adv_d_real']
        adv_d_fake = self.loss['adv_d_fake']
        regul = self.loss['regul']
        adv_g = self.loss['adv_g']
        style_recon_loss = self.loss['style_recon_loss']
        ds_loss = self.loss['ds_loss']
        cyc_loss = self.loss['cyc_loss']
        print(
            "[Epoch {}/{}, iters: {}/{}] "
            "- Adv: {:5.4} {:5.4} / {:5.4}, Style recon: {:5.4}, DS: {:5.4}, Cyc : {:5.4}, Regul : {:5.4}".format(
                epoch, self.epoch_size, iters + 1, self.iter_size,
                adv_d_real, adv_d_fake, adv_g,
                style_recon_loss, ds_loss, cyc_loss, regul))

    def save_models(self, epoch):
        os.makedirs(self.path_model, exist_ok=True)
        state = {
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_d': self.optimizer_d.state_dict(),
            'optimizer_g': self.optimizer_g.state_dict(),
            # 'scheduler_d': self.scheduler_d.state_dict(),  # TODO
            # 'scheduler_g': self.scheduler_g.state_dict(),
            'w_ds': self.w_ds,
            'current_epoch': epoch,
        }
        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        torch.save(state, save_name)

    def load_models(self, epoch=False):
        if not epoch:
            last_model_path = sorted(glob.glob(os.path.join(self.path_model, '*')))[-1]
            epoch = int(last_model_path.split('/')[-1].split('_')[1][:2])
        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        checkpoint = torch.load(save_name)
        # weight
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.generator.load_state_dict(checkpoint['generator'])
        self.optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        self.optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        # self.scheduler_d.load_state_dict(checkpoint['scheduler_d'])
        # self.scheduler_g.load_state_dict(checkpoint['scheduler_g'])
        self.w_ds = checkpoint['w_ds']
        self.current_epoch = checkpoint['current_epoch']
        return epoch

    def resume_train(self, restart_epoch=False):
        restart_epoch = self.load_models(restart_epoch)
        print("Resume Training - Epoch: ",
restart_epoch) self.train_starGAN(restart_epoch + 1) def generate_test_samples(self, save): os.makedirs(self.path_sample, exist_ok=True) with torch.no_grad(): reference, reference_domain, _ = next(iter(self.test_loader)) reference, reference_domain = reference.to( self.device), reference_domain.to(self.device) style_reference = self.style_encoder(reference, reference_domain) style_reference = style_reference.repeat(1, reference.size(0), 1).view( -1, 1, self.dim_style) source = self.test_source.repeat(reference.size(0), 1, 1, 1).view(-1, 3, self.height, self.width) generated = self.generator(source, style_reference).clamp(-1, 1) right_concat, _, _ = reshape_batch_torch( torch.cat([self.test_source, generated]), n_cols=self.test_batch_size, n_rows=-1) left_concat = torch.cat( [torch.zeros_like(reference[:1]), reference]) left_concat, _, _ = reshape_batch_torch(left_concat, n_cols=1, n_rows=-1) save_image = preprocess( np.concatenate([left_concat, right_concat], axis=1)) if save: save_name = os.path.join( self.path_sample, "{:02}_{:02}.jpg".format(self.current_epoch, self.save_img_cnt)) self.save_img_cnt += 1 plt.imsave(save_name, save_image) print("Test samples Saved:" + save_name) return save_image
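# --- Hedged usage sketch (not part of the original script) -------------------
# The StarGAN trainer above is driven entirely by a config dict. The keys below
# mirror the ones read in __init__ and train_starGAN; every value is
# illustrative only, and the network sub-configs ('gen', 'style_encoder',
# 'mapping_network', 'dis') are placeholders whose real contents depend on the
# Generator / StyleEncoder / MappingNetwork / Discriminator definitions
# elsewhere in this repo.
example_config = {
    'lr': 1e-4, 'lr_F': 1e-6, 'beta1': 0.0, 'beta2': 0.99, 'init': 'kaiming',
    'batch_size': 8, 'gan_type': 'bce', 'max_iter': 100000, 'crop_size': 256,
    'save_name': 'stargan_run', 'w_style': 1.0, 'w_ds': 1.0, 'w_cyc': 1.0,
    'w_regul': 1.0, 'dim_style': 64,
    'mapping_network': {'dim_latent': 16},
    'gen': {}, 'style_encoder': {}, 'dis': {},  # network sub-configs (placeholders)
    'd_step': 1, 'g_step': 1, 'log_iter': 100,
    'image_display_iter': 500, 'image_save_iter': 1000,
}
# model = StarGAN(example_config, train_loader, test_loader)
# model.train_starGAN(init_epoch=0)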
class SAGAN:
    def __init__(self, args):
        self.args = args
        # move both networks to the training device up front; inputs are moved
        # to args.device in the training loop below
        self.gen_model = Generator(args.channels, args.image_size,
                                   args.latent_dim, args.ngf).to(args.device)
        self.dis_model = Discriminator(args.channels, args.image_size,
                                       args.ndf).to(args.device)
        self.gen_opt = torch.optim.Adam(self.gen_model.parameters(),
                                        lr=args.gen_lr,
                                        betas=(args.beta1, args.beta2),
                                        weight_decay=args.weight_decay)
        self.dis_opt = torch.optim.Adam(self.dis_model.parameters(),
                                        lr=args.dis_lr,
                                        betas=(args.beta1, args.beta2),
                                        weight_decay=args.weight_decay)
        self.anime_dataset = AnimeDataset(args.base_image_path)
        self.train_loader = DataLoader(self.anime_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True, drop_last=False)

    def train_one_epoch(self, epoch):
        self.gen_model.train()
        self.dis_model.train()
        print('[INFO] Epoch:', epoch)
        pbar = tqdm(self.train_loader, total=len(self.train_loader))
        acc_d_loss, acc_g_loss = 0, 0
        for images in pbar:
            images = images.to(self.args.device).float()
            # match the actual batch size: with drop_last=False the final
            # batch can be smaller than args.batch_size
            latents = torch.randn(images.size(0), self.args.latent_dim, 1, 1,
                                  device=self.args.device).float()
            fake_images = self.gen_model(latents)
            dis_real = self.dis_model(images)
            dis_fake = self.dis_model(fake_images.detach())
            self.dis_opt.zero_grad()
            dis_loss = dis_hinge_loss(dis_fake, dis_real)
            dis_loss.backward()
            self.dis_opt.step()
            latents = torch.randn(images.size(0), self.args.latent_dim, 1, 1,
                                  device=self.args.device).float()
            fake_images = self.gen_model(latents)
            dis_fake = self.dis_model(fake_images)
            self.gen_opt.zero_grad()
            gen_loss = gen_hinge_loss(dis_fake)
            gen_loss.backward()
            self.gen_opt.step()
            # exponential moving averages of both losses (the generator line
            # had `*` where `+` was intended)
            acc_d_loss = acc_d_loss * self.args.loss_smooth + dis_loss.detach().cpu().item() * (1 - self.args.loss_smooth)
            acc_g_loss = acc_g_loss * self.args.loss_smooth + gen_loss.detach().cpu().item() * (1 - self.args.loss_smooth)

    def visualize(self, num_samples=20):
        self.gen_model.eval()
        self.dis_model.eval()
        latents = torch.randn(num_samples, self.args.latent_dim, 1, 1,
                              device=self.args.device)
        fake_images = self.gen_model(latents).detach().cpu().numpy()
        fake_images = fake_images * .5 + .5
        fake_images = np.transpose(fake_images, (0, 2, 3, 1))
        plt.figure(figsize=(10, 10))
        for i in range(num_samples):
            plt.subplot(4, num_samples // 4, i + 1)
            plt.imshow(fake_images[i])
        plt.savefig(str(time()) + '.jpg')

    def save_checkpoints(self, epoch):
        if not os.path.exists(self.args.checkpoints_path):
            os.mkdir(self.args.checkpoints_path)
        torch.save({
            'gen_model': self.gen_model.state_dict(),
            'dis_model': self.dis_model.state_dict(),
            'gen_opt': self.gen_opt.state_dict(),
            'dis_opt': self.dis_opt.state_dict()  # was dis_model.state_dict()
        }, f'{self.args.checkpoints_path}/epoch_{epoch}_{time()}.tar')

    def train(self):
        for epoch in range(self.args.epochs):
            self.train_one_epoch(epoch + 1)
            self.visualize()
            # save on the first epoch and then every checkpoint_step epochs;
            # the original condition was missing the `== 0`
            if epoch == 0 or (epoch + 1) % self.args.checkpoint_step == 0:
                self.save_checkpoints(epoch + 1)
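# --- Hedged sketch (assumed helpers) -----------------------------------------
# The SAGAN trainer above calls dis_hinge_loss / gen_hinge_loss, which are
# defined elsewhere in this repo. The standard hinge GAN losses used by SAGAN
# (Lim & Ye 2017; Zhang et al. 2019), matching those call signatures, would be:
import torch

def dis_hinge_loss(dis_fake, dis_real):
    # penalize real scores below +1 and fake scores above -1
    return torch.relu(1.0 - dis_real).mean() + torch.relu(1.0 + dis_fake).mean()

def gen_hinge_loss(dis_fake):
    # the generator simply maximizes the discriminator's score on fakes
    return -dis_fake.mean()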
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir',
                        help='Directory containing xxx_i_s and xxx_i_t with same prefix',
                        default=cfg.example_data_dir)
    parser.add_argument('--save_dir', help='Directory to save result',
                        default=cfg.predict_result_dir)
    parser.add_argument('--checkpoint', help='ckpt', default=cfg.ckpt_path)
    args = parser.parse_args()

    assert args.input_dir is not None
    assert args.save_dir is not None
    assert args.checkpoint is not None

    print_log('model compiling start.', content_color=PrintColor['yellow'])

    G = Generator(in_channels=3).to(device)
    D1 = Discriminator(in_channels=6).to(device)
    D2 = Discriminator(in_channels=6).to(device)
    vgg_features = Vgg19().to(device)

    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate,
                                betas=(cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate,
                                 betas=(cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate,
                                 betas=(cfg.beta1, cfg.beta2))

    checkpoint = torch.load(args.checkpoint)
    G.load_state_dict(checkpoint['generator'])
    D1.load_state_dict(checkpoint['discriminator1'])
    D2.load_state_dict(checkpoint['discriminator2'])
    G_solver.load_state_dict(checkpoint['g_optimizer'])
    D1_solver.load_state_dict(checkpoint['d1_optimizer'])
    D2_solver.load_state_dict(checkpoint['d2_optimizer'])

    trfms = To_tensor()
    example_data = example_dataset(data_dir=args.input_dir, transform=trfms)
    example_loader = DataLoader(dataset=example_data, batch_size=1, shuffle=False)
    example_iter = iter(example_loader)

    print_log('Model compiled.', content_color=PrintColor['yellow'])
    print_log('Predicting', content_color=PrintColor['yellow'])

    G.eval()
    D1.eval()
    D2.eval()

    with torch.no_grad():
        for step in tqdm(range(len(example_data))):
            try:
                # next(iterator) replaces the Python-2-only iterator.next()
                inp = next(example_iter)
            except StopIteration:
                example_iter = iter(example_loader)
                inp = next(example_iter)

            i_t = inp[0].to(device)
            i_s = inp[1].to(device)
            name = str(inp[2][0])

            o_sk, o_t, o_b, o_f = G(i_t, i_s, (i_t.shape[2], i_t.shape[3]))

            o_sk = o_sk.squeeze(0).detach().to('cpu')
            o_t = o_t.squeeze(0).detach().to('cpu')
            o_b = o_b.squeeze(0).detach().to('cpu')
            o_f = o_f.squeeze(0).detach().to('cpu')

            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)

            o_sk = F.to_pil_image(o_sk)
            o_t = F.to_pil_image((o_t + 1) / 2)
            o_b = F.to_pil_image((o_b + 1) / 2)
            o_f = F.to_pil_image((o_f + 1) / 2)

            o_f.save(os.path.join(args.save_dir, name + 'o_f.png'))
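# --- Hedged sketch (hypothetical helper, not in the original) ----------------
# The loop above converts all four generator outputs to PIL images but only
# writes o_f to disk. If the skeleton/text/background maps are wanted as well,
# a small helper along these lines could replace the single o_f.save call
# (assumes `os` is already imported at the top of this script, as used above):
def save_all_outputs(save_dir, name, o_sk, o_t, o_b, o_f):
    # each argument is already a PIL image, as produced in the loop above
    for suffix, img in (('o_sk', o_sk), ('o_t', o_t), ('o_b', o_b), ('o_f', o_f)):
        img.save(os.path.join(save_dir, name + suffix + '.png'))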
) * hr_img.size(0) torch.save(generator_net.state_dict(), weights_dir + 'G_epoch_%d.pth' % (epoch)) generator_losses.append( (epoch, generator_running_loss / len(train_set))) discriminator_losses.append( (epoch, discriminator_running_loss / len(train_set))) if epoch % 50 == 0: with torch.no_grad(): cur_epoch_dir = imgout_dir + str(epoch) + '/' os.makedirs(cur_epoch_dir, exist_ok=True) generator_net.eval() discriminator_net.eval() valid_bar = tqdm(validloader) img_count = 0 psnr_avg = 0.0 psnr = 0.0 for hr_img, lr_img in valid_bar: valid_bar.set_description('Img: %i PSNR: %f' % (img_count, psnr)) if torch.cuda.is_available(): lr_img = lr_img.cuda() hr_img = hr_img.cuda() sr_tensor = generator_net(lr_img) mse = torch.mean((hr_img - sr_tensor)**2) psnr = 10 * (torch.log10(1 / mse) + np.log10(4)) psnr_avg += psnr img_count += 1
class ModelBuilder(object): def __init__(self, use_cuda): self.cuda = use_cuda self._pre_data() self._build_model() self.i_mb = 0 def _pre_data(self): print('pre data...') self.data = Data(self.cuda) def _build_model(self): print('building model...') we = torch.load('./data/processed/we.pkl') self.i_encoder = CNN_Args_encoder(we) self.a_encoder = CNN_Args_encoder(we, need_kmaxavg=True) self.classifier = Classifier() self.discriminator = Discriminator() if self.cuda: self.i_encoder.cuda() self.a_encoder.cuda() self.classifier.cuda() self.discriminator.cuda() self.criterion_c = torch.nn.CrossEntropyLoss() self.criterion_d = torch.nn.BCELoss() para_filter = lambda model: filter(lambda p: p.requires_grad, model.parameters()) self.i_optimizer = torch.optim.Adagrad(para_filter(self.i_encoder), Config.lr, weight_decay=Config.l2_penalty) self.a_optimizer = torch.optim.Adagrad(para_filter(self.a_encoder), Config.lr, weight_decay=Config.l2_penalty) self.c_optimizer = torch.optim.Adagrad(self.classifier.parameters(), Config.lr, weight_decay=Config.l2_penalty) self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), Config.lr_d, weight_decay=Config.l2_penalty) def _print_train(self, epoch, time, loss, acc): print('-' * 80) print( '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |' .format(epoch, time, loss, acc * 100)) print('-' * 80) def _print_eval(self, task, loss, acc): print('| ' + task + ' loss {:10.5f} | acc {:5.2f}% |'.format(loss, acc * 100)) print('-' * 80) def _save_model(self, model, filename): torch.save(model.state_dict(), './weights/' + filename) def _load_model(self, model, filename): model.load_state_dict(torch.load('./weights/' + filename)) def _pretrain_i_one(self): self.i_encoder.train() self.classifier.train() total_loss = 0 correct_n = 0 for a1, a2i, a2a, sense in self.data.train_loader: if self.cuda: a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda( ), sense.cuda() a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable( a2a), Variable(sense) output = self.classifier(self.i_encoder(a1, a2i)) _, output_sense = torch.max(output, 1) assert output_sense.size() == sense.size() tmp = (output_sense == sense).long() correct_n += torch.sum(tmp).data loss = self.criterion_c(output, sense) self.i_optimizer.zero_grad() self.c_optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(), Config.grad_clip) torch.nn.utils.clip_grad_norm(self.classifier.parameters(), Config.grad_clip) self.i_optimizer.step() self.c_optimizer.step() total_loss += loss.data * sense.size(0) return total_loss[0] / self.data.train_size, correct_n[ 0] / self.data.train_size def _pretrain_i_a_one(self): self.i_encoder.train() self.a_encoder.train() self.classifier.train() total_loss = 0 correct_n = 0 total_loss_a = 0 correct_n_a = 0 for a1, a2i, a2a, sense in self.data.train_loader: if self.cuda: a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda( ), sense.cuda() a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable( a2a), Variable(sense) # train i output = self.classifier(self.i_encoder(a1, a2i)) _, output_sense = torch.max(output, 1) assert output_sense.size() == sense.size() tmp = (output_sense == sense).long() correct_n += torch.sum(tmp).data loss = self.criterion_c(output, sense) self.i_optimizer.zero_grad() self.c_optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(), Config.grad_clip) torch.nn.utils.clip_grad_norm(self.classifier.parameters(), Config.grad_clip) 
self.i_optimizer.step() self.c_optimizer.step() total_loss += loss.data * sense.size(0) #train a output = self.classifier(self.a_encoder(a1, a2a)) _, output_sense = torch.max(output, 1) assert output_sense.size() == sense.size() tmp = (output_sense == sense).long() correct_n_a += torch.sum(tmp).data loss = self.criterion_c(output, sense) self.a_optimizer.zero_grad() self.c_optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm(self.a_encoder.parameters(), Config.grad_clip) torch.nn.utils.clip_grad_norm(self.classifier.parameters(), Config.grad_clip) self.a_optimizer.step() self.c_optimizer.step() total_loss_a += loss.data * sense.size(0) return total_loss[0] / self.data.train_size, correct_n[ 0] / self.data.train_size, total_loss_a[ 0] / self.data.train_size, correct_n_a[0] / self.data.train_size def _adtrain_one(self, acc_d_for_train): total_loss = 0 total_loss_2 = 0 correct_n = 0 correct_n_d = 0 correct_n_d_for_train = 0 for a1, a2i, a2a, sense in self.data.train_loader: if self.cuda: a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda( ), sense.cuda() a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable( a2a), Variable(sense) # phase 1, train discriminator flag = 0 for k in range(Config.kd): # if self._test_d() != 1: if True: temp_d = 0 self.a_encoder.eval() self.i_encoder.eval() self.discriminator.train() self.d_optimizer.zero_grad() output_i = self.discriminator(self.i_encoder(a1, a2i)) temp_d += torch.sum((output_i < 0.5).long()).data # zero_tensor = torch.zeros(output_i.size()) zero_tensor = torch.Tensor(output_i.size()).random_( 0, 100) * 0.003 if self.cuda: zero_tensor = zero_tensor.cuda() zero_tensor = Variable(zero_tensor) d_loss_i = self.criterion_d(output_i, zero_tensor) d_loss_i.backward() output_a = self.discriminator(self.a_encoder(a1, a2a)) temp_d += torch.sum((output_a >= 0.5).long()).data # one_tensor = torch.ones(output_a.size()) # one_tensor = torch.Tensor(output_a.size()).fill_(Config.alpha) one_tensor = torch.Tensor(output_a.size()).random_( 0, 100) * 0.005 + 0.7 if self.cuda: one_tensor = one_tensor.cuda() one_tensor = Variable(one_tensor) d_loss_a = self.criterion_d(output_a, one_tensor) d_loss_a.backward() correct_n_d_for_train += temp_d temp_d = max(temp_d[0] / sense.size(0) / 2, acc_d_for_train) if temp_d < Config.thresh_high: torch.nn.utils.clip_grad_norm( self.discriminator.parameters(), Config.grad_clip) self.d_optimizer.step() # phase 2, train i/c self.i_encoder.train() self.classifier.train() self.discriminator.eval() self.i_optimizer.zero_grad() self.c_optimizer.zero_grad() sent_repr = self.i_encoder(a1, a2i) output = self.classifier(sent_repr) _, output_sense = torch.max(output, 1) assert output_sense.size() == sense.size() tmp = (output_sense == sense).long() correct_n += torch.sum(tmp).data loss_1 = self.criterion_c(output, sense) output_d = self.discriminator(sent_repr) correct_n_d += torch.sum((output_d < 0.5).long()).data one_tensor = torch.ones(output_d.size()) # one_tensor = torch.Tensor(output_d.size()).fill_(Config.alpha) # one_tensor = torch.Tensor(output_d.size()).random_(0,100) * 0.005 + 0.7 if self.cuda: one_tensor = one_tensor.cuda() one_tensor = Variable(one_tensor) loss_2 = self.criterion_d(output_d, one_tensor) loss = loss_1 + loss_2 * Config.lambda1 loss.backward() torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(), Config.grad_clip) torch.nn.utils.clip_grad_norm(self.classifier.parameters(), Config.grad_clip) self.i_optimizer.step() self.c_optimizer.step() total_loss += loss.data * sense.size(0) total_loss_2 
+= loss_2.data * sense.size(0) test_loss, test_acc = self._eval('test', 'i') self.logwriter.add_scalar('acc/test_acc_t_mb', test_acc * 100, self.i_mb) self.i_mb += 1 return total_loss[0] / self.data.train_size, correct_n[ 0] / self.data.train_size, correct_n_d[ 0] / self.data.train_size, total_loss_2[ 0] / self.data.train_size, correct_n_d_for_train[ 0] / self.data.train_size / 2 def _pretrain_i(self): best_test_acc = 0 for epoch in range(Config.pre_i_epochs): start_time = time.time() loss, acc = self._pretrain_i_one() self._print_train(epoch, time.time() - start_time, loss, acc) self.logwriter.add_scalar('loss/train_loss_i', loss, epoch) self.logwriter.add_scalar('acc/train_acc_i', acc * 100, epoch) dev_loss, dev_acc = self._eval('dev', 'i') self._print_eval('dev', dev_loss, dev_acc) self.logwriter.add_scalar('loss/dev_loss_i', dev_loss, epoch) self.logwriter.add_scalar('acc/dev_acc_i', dev_acc * 100, epoch) test_loss, test_acc = self._eval('test', 'i') self._print_eval('test', test_loss, test_acc) self.logwriter.add_scalar('loss/test_loss_i', test_loss, epoch) self.logwriter.add_scalar('acc/test_acc_i', test_acc * 100, epoch) if test_acc >= best_test_acc: best_test_acc = test_acc self._save_model(self.i_encoder, 'i.pkl') self._save_model(self.classifier, 'c.pkl') print('i_model saved at epoch {}'.format(epoch)) def _adjust_learning_rate(self, optimizer, lr): for param_group in optimizer.param_groups: param_group['lr'] = lr def _train_together(self): best_test_acc = 0 loss = acc = loss_a = acc_a = 0 lr_t = Config.lr_t acc_d_for_train = 0 for epoch in range(Config.together_epochs): start_time = time.time() if epoch < Config.first_stage_epochs: loss, acc, loss_a, acc_a = self._pretrain_i_a_one() else: if epoch == Config.first_stage_epochs: self._adjust_learning_rate(self.i_optimizer, lr_t) self._adjust_learning_rate(self.c_optimizer, lr_t / 2) # elif (epoch - Config.first_stage_epochs) % 20 == 0: # lr_t *= 0.8 # self._adjust_learning_rate(self.i_optimizer, lr_t) # self._adjust_learning_rate(self.c_optimizer, lr_t) loss, acc, acc_d, loss_2, acc_d_for_train = self._adtrain_one( acc_d_for_train) self._print_train(epoch, time.time() - start_time, loss, acc) self.logwriter.add_scalar('loss/train_loss_t', loss, epoch) self.logwriter.add_scalar('acc/train_acc_t', acc * 100, epoch) self.logwriter.add_scalar('loss/train_loss_t_a', loss_a, epoch) self.logwriter.add_scalar('acc/train_acc_t_a', acc_a * 100, epoch) if epoch >= Config.first_stage_epochs: self.logwriter.add_scalar('acc/train_acc_d', acc_d * 100, epoch) self.logwriter.add_scalar('loss/train_loss_2', loss_2, epoch) self.logwriter.add_scalar('acc/acc_d_for_train', acc_d_for_train * 100, epoch) dev_loss, dev_acc = self._eval('dev', 'i') dev_loss_a, dev_acc_a = self._eval('dev', 'a') self._print_eval('dev', dev_loss, dev_acc) self.logwriter.add_scalar('loss/dev_loss_t', dev_loss, epoch) self.logwriter.add_scalar('acc/dev_acc_t', dev_acc * 100, epoch) self.logwriter.add_scalar('loss/dev_loss_t_a', dev_loss_a, epoch) self.logwriter.add_scalar('acc/dev_acc_t_a', dev_acc_a * 100, epoch) if epoch >= Config.first_stage_epochs: dev_acc_d = self._eval_d('dev') self.logwriter.add_scalar('acc/dev_acc_d', dev_acc_d * 100, epoch) test_loss, test_acc = self._eval('test', 'i') test_loss_a, test_acc_a = self._eval('test', 'a') self._print_eval('test', test_loss, test_acc) self.logwriter.add_scalar('loss/test_loss_t', test_loss, epoch) self.logwriter.add_scalar('acc/test_acc_t', test_acc * 100, epoch) self.logwriter.add_scalar('loss/test_loss_t_a', 
test_loss_a, epoch) self.logwriter.add_scalar('acc/test_acc_t_a', test_acc_a * 100, epoch) if epoch >= Config.first_stage_epochs: test_acc_d = self._eval_d('test') self.logwriter.add_scalar('acc/test_acc_d', test_acc_d * 100, epoch) if test_acc >= best_test_acc: best_test_acc = test_acc self._save_model(self.i_encoder, 't_i.pkl') self._save_model(self.classifier, 't_c.pkl') print('t_i t_c saved at epoch {}'.format(epoch)) def train(self, i_or_t): print('start training') self.logwriter = SummaryWriter(Config.logdir) if i_or_t == 'i': self._pretrain_i() elif i_or_t == 't': self._train_together() else: raise Exception('wrong i_or_t') print('training done') def _eval(self, task, i_or_a): self.i_encoder.eval() self.a_encoder.eval() self.classifier.eval() total_loss = 0 correct_n = 0 if task == 'dev': data = self.data.dev_loader n = self.data.dev_size elif task == 'test': data = self.data.test_loader n = self.data.test_size else: raise Exception('wrong eval task') for a1, a2i, a2a, sense1, sense2 in data: if self.cuda: a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda(), a2a.cuda( ), sense1.cuda(), sense2.cuda() a1 = Variable(a1, volatile=True) a2i = Variable(a2i, volatile=True) a2a = Variable(a2a, volatile=True) sense1 = Variable(sense1, volatile=True) sense2 = Variable(sense2, volatile=True) if i_or_a == 'i': output = self.classifier(self.i_encoder(a1, a2i)) elif i_or_a == 'a': output = self.classifier(self.a_encoder(a1, a2a)) else: raise Exception('wrong i_or_a') _, output_sense = torch.max(output, 1) assert output_sense.size() == sense1.size() gold_sense = sense1 mask = (output_sense == sense2) gold_sense[mask] = sense2[mask] tmp = (output_sense == gold_sense).long() correct_n += torch.sum(tmp).data loss = self.criterion_c(output, gold_sense) total_loss += loss.data * gold_sense.size(0) return total_loss[0] / n, correct_n[0] / n def _eval_d(self, task): self.i_encoder.eval() self.a_encoder.eval() self.classifier.eval() correct_n = 0 if task == 'train': n = self.data.train_size for a1, a2i, a2a, sense in self.data.train_loader: if self.cuda: a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda( ), sense.cuda() a1 = Variable(a1, volatile=True) a2i = Variable(a2i, volatile=True) a2a = Variable(a2a, volatile=True) sense = Variable(sense, volatile=True) output_i = self.discriminator(self.i_encoder(a1, a2i)) correct_n += torch.sum((output_i < 0.5).long()).data # output_a = self.discriminator(self.a_encoder(a1, a2a)) # correct_n += torch.sum((output_a >= 0.5).long()).data else: if task == 'dev': data = self.data.dev_loader n = self.data.dev_size elif task == 'test': data = self.data.test_loader n = self.data.test_size for a1, a2i, a2a, sense1, sense2 in data: if self.cuda: a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda( ), a2a.cuda(), sense1.cuda(), sense2.cuda() a1 = Variable(a1, volatile=True) a2i = Variable(a2i, volatile=True) a2a = Variable(a2a, volatile=True) sense1 = Variable(sense1, volatile=True) sense2 = Variable(sense2, volatile=True) output_i = self.discriminator(self.i_encoder(a1, a2i)) correct_n += torch.sum((output_i < 0.5).long()).data # output_a = self.discriminator(self.a_encoder(a1, a2a)) # correct_n += torch.sum((output_a >= 0.5).long()).data return correct_n[0] / n def _test_d(self): acc = self._eval_d('dev') phase = -100 if acc >= Config.thresh_high: phase = 1 elif acc > Config.thresh_low: phase = 0 else: phase = -1 return phase def eval(self, stage): if stage == 'i': self._load_model(self.i_encoder, 'i.pkl') self._load_model(self.classifier, 'c.pkl') test_loss, test_acc 
= self._eval('test', 'i') self._print_eval('test', test_loss, test_acc) elif stage == 't': self._load_model(self.i_encoder, 't_i.pkl') self._load_model(self.classifier, 't_c.pkl') test_loss, test_acc = self._eval('test', 'i') self._print_eval('test', test_loss, test_acc) else: raise Exception('wrong eval stage')
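# --- Modernization sketch (not in the original) ------------------------------
# ModelBuilder's evaluation loops use the pre-PyTorch-0.4 idiom
# Variable(x, volatile=True); on current PyTorch the equivalent is to run the
# forward passes under torch.no_grad(). A minimal illustration of the pattern:
import torch

def eval_batches(model, batches):
    model.eval()
    with torch.no_grad():  # replaces Variable(..., volatile=True)
        return [model(x) for x in batches]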
class Solver: def __init__(self, loader): self.loader = loader self.c_dim = 4 self.lambda_cls = 10.0 self.lambda_rec = 10.0 self.lambda_gp = 10.0 self.g_lr = 0.0001 self.d_lr = 0.0001 self.n_critic = 6 self.beta1 = 0.5 self.beta2 = 0.999 self.smooth_beta = 0.999 self.model_save_step = 1000 self.lr_update_step = 1000 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.image_size = 256 self.num_iters = 200000 self.num_iters_decay = 100000 self.log_step = 10 self.sample_step = 10 # Directories. self.log_dir = "log" self.sample_dir = "sample" self.model_save_dir = "model" self.result_dir = "result" # colors self.colors = until.colors self.void_classes = until.void_classes self.valid_classes = until.valid_classes self.class_names = until.class_names self.ignore_index = until.ignore_index self.n_classes = until.n_classes self.label_colours = dict(zip(range(19), self.colors)) self.class_map = dict(zip(self.valid_classes, range(19))) self.class_names = dict(zip(self.class_names, range(19))) print(self.class_names) self.build_model() def build_model(self): self.G = Generator(conv_dim=64, c_dim=self.c_dim) self.G_test = Generator(conv_dim=64, c_dim=self.c_dim) self.D = Discriminator(self.image_size, 64, self.c_dim) self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) #self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), lr=self.g_lr, alpha=0.99, eps=1e-8) #self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), lr=self.d_lr, alpha=0.99, eps=1e-8) self.G.to(self.device) self.G_test.to(self.device) self.D.to(self.device) self.update_average(self.G_test, self.G, 0.) def eval_model(self): self.G.eval() self.G_test.eval() self.D.eval() def restore_model(self, resume_iters): """Restore the trained generator and discriminator.""" print('Loading the trained models from step {}...'.format(resume_iters)) G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) G_test_path = os.path.join(self.model_save_dir, '{}-G_test.ckpt'.format(resume_iters)) D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage)) self.G_test.load_state_dict(torch.load(G_test_path, map_location=lambda storage, loc: storage)) self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) def reset_grad(self): """Reset the gradient buffers.""" self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() def denorm(self, x): """Convert the range from [-1, 1] to [0, 1].""" out = torch.flip(x, [1]) #out = (x + 1) / 2 return out.clamp_(0, 1) def update_lr(self, g_lr, d_lr): """Decay learning rates of the generator and discriminator.""" for param_group in self.g_optimizer.param_groups: param_group['lr'] = g_lr for param_group in self.d_optimizer.param_groups: param_group['lr'] = d_lr def label2onehot(self, labels, dim): """Convert label indices to one-hot vectors.""" batch_size = labels.size(0) out = torch.zeros(batch_size, dim) out[np.arange(batch_size), labels.long()] = 1 return out def gradient_penalty(self, y, x): """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" weight = torch.ones(y.size()).to(self.device) dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight, retain_graph=True, create_graph=True, only_inputs=True)[0] dydx = dydx.view(dydx.size(0), -1) dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) 
        return torch.mean((dydx_l2norm - 1)**2)

    def classification_loss(self, logit, target):
        """Compute binary or softmax cross entropy loss."""
        # reduction='sum' replaces the deprecated size_average=False
        return F.binary_cross_entropy_with_logits(
            logit, target, reduction='sum') / logit.size(0)

    def update_average(self, model_tgt, model_src, beta):
        toogle_grad(model_src, False)
        toogle_grad(model_tgt, False)
        param_dict_src = dict(model_src.named_parameters())
        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            assert (p_src is not p_tgt)
            p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)

    def get_zdist(self, dist_name, dim, device=None):
        # Get distribution
        if dist_name == 'uniform':
            low = -torch.ones(dim, device=device)
            high = torch.ones(dim, device=device)
            zdist = distributions.Uniform(low, high)
        elif dist_name == 'gauss':
            mu = torch.zeros(dim, device=device)
            scale = torch.ones(dim, device=device)
            zdist = distributions.Normal(mu, scale)
        else:
            raise NotImplementedError
        # Add dim attribute
        zdist.dim = dim
        return zdist

    def getBatch(self):
        try:
            x_real, label_org = next(self.data_iter)
        except:
            while True:
                try:
                    self.data_iter = iter(self.loader)
                    x_real, label_org = next(self.data_iter)
                    break
                except:
                    pass
        return x_real, label_org

    def onehot(self, label):
        label = label.numpy()
        label_onehot = np.zeros((label.shape[0], self.n_classes,
                                 label.shape[1], label.shape[2])).astype(np.uint8)
        for i in range(self.n_classes):
            label_onehot[:, i, :, :] = (label == i)
        label_onehot = torch.from_numpy(label_onehot)
        return label_onehot.to(self.device)

    def to_label(self, label_onehot):
        label = np.zeros((label_onehot.shape[0], 1, label_onehot.shape[2],
                          label_onehot.shape[3])).astype(np.uint8)
        label[:, 0, :, :] = np.argmax(label_onehot, axis=1)
        label = torch.from_numpy(label)
        return label

    def vis(self, real, label_onehot):
        label = self.to_label(label_onehot)
        label = label.numpy()
        label_colors = np.zeros((label.shape[0], 3, label.shape[2],
                                 label.shape[3])).astype(np.uint8)
        r = label.copy()
        g = label.copy()
        b = label.copy()
        for l in range(0, self.n_classes):
            r[label == l] = self.label_colours[l][0]
            g[label == l] = self.label_colours[l][1]
            b[label == l] = self.label_colours[l][2]
        r = np.reshape(r, ((label.shape[0], label.shape[2], label.shape[3])))
        g = np.reshape(g, ((label.shape[0], label.shape[2], label.shape[3])))
        b = np.reshape(b, ((label.shape[0], label.shape[2], label.shape[3])))
        rgb = np.zeros((label.shape[0], 3, label.shape[2], label.shape[3]))
        rgb[:, 0, :, :] = r / 255.0
        rgb[:, 1, :, :] = g / 255.0
        rgb[:, 2, :, :] = b / 255.0
        rgb = torch.from_numpy(rgb)
        save_image(rgb, "label.jpg", nrow=1, padding=0)
        save_image(self.denorm(real.data.cpu()), "real.jpg", nrow=1, padding=0)

    # What fraction of the image each label occupies
    def label_contain_percent(self, label, index=None):
        label_per = np.zeros((label.shape[0], self.n_classes, 1, 1)).astype(np.float32)
        Ns = torch.sum(label < self.n_classes, (1, 2), dtype=torch.float32)
        if index is None:
            for i in range(self.n_classes):
                label_per[:, i, 0, 0] = torch.sum(label == i, (1, 2),
                                                  dtype=torch.float32) / Ns
        else:
            for i in range(len(index)):
                label_per[i, index[i], 0, 0] = torch.sum(
                    label == index[i], (1, 2), dtype=torch.float32)[i] / Ns[i]
        label_per = torch.from_numpy(label_per)
        return label_per.view(label_per.size()[0], -1).to(self.device)

    def train(self, start_iter=0):
        g_lr = self.g_lr
        d_lr = self.d_lr
        zdist = None
        BCELoss = torch.nn.BCELoss()
        if start_iter > 0:
            self.restore_model(start_iter)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iter, self.num_iters):
            x_real, label = self.getBatch()
            label = label.clone()
            label_onehot = self.onehot(label)
            # input images
            x_real = x_real.to(self.device)
            if zdist is None:
                zdist = self.get_zdist("uniform",
                                       (3, x_real.size(2), x_real.size(3)),
                                       device=self.device)
            # make noise
            noise = zdist.sample((x_real.size(0),))

            if (i) % 1 == 0:
                # train discriminator (every iteration)
                toogle_grad(self.G, False)
                toogle_grad(self.D, True)
                self.vis(x_real, label_onehot)
                # categories to hide; the second line overrides the random
                # choice with the fixed 'car' category
                hidden_categorys = [np.random.randint(self.n_classes)
                                    for _ in range(x_real.size()[0])]
                hidden_categorys = [self.class_names['car']
                                    for _ in range(x_real.size()[0])]
                # convert to one-hot vectors
                hidden_categorys_onehot = np.eye(self.n_classes, dtype=np.float32)[hidden_categorys]
                hidden_categorys_onehot = torch.from_numpy(hidden_categorys_onehot).to(self.device)
                # fraction of each label present in the ground truth
                label_per_real = self.label_contain_percent(label)  # shape [batch, 19, 1, 1]
                out_src, out_cls_real = self.D(x_real)
                d_loss_real = -torch.mean(out_src)
                # class-ratio loss
                d_loss_cls_real = naive_cross_entropy_loss(out_cls_real, label_per_real)
                x_mask = self.G(x_real, hidden_categorys_onehot)
                x_fake = x_mask * x_real + (1.0 - x_mask) * noise
                out_src_fake, out_cls_fake = self.D(x_fake.detach())
                # class-ratio loss
                d_loss_cls_fake = naive_cross_entropy_loss(out_cls_fake, label_per_real)
                d_loss_fake = torch.mean(out_src_fake)
                # gradient-penalty loss
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                x_hat = (alpha * x_real.data + (1.0 - alpha) * x_fake.data).requires_grad_(True)
                out_src, _ = self.D(x_hat)
                d_loss_gp = self.gradient_penalty(out_src, x_hat)
                d_loss = d_loss_real + d_loss_fake \
                    + self.lambda_cls * (d_loss_cls_real + d_loss_cls_fake) \
                    + self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls_real'] = d_loss_cls_real.item()
                loss['D/loss_cls_fake'] = d_loss_cls_fake.item()
                loss['D/loss_gp'] = d_loss_gp.item()

            # train generator
            if (i + 1) % self.n_critic == 0:
                toogle_grad(self.G, True)
                toogle_grad(self.D, False)
                x_mask = self.G(x_real, hidden_categorys_onehot)
                x_fake = x_mask * x_real + (1.0 - x_mask) * noise
                out_src, out_cls = self.D(x_fake)
                label_real = torch.full((x_real.size(0), 1), 1.0, device=self.device)
                g_loss_fake = -torch.mean(out_src)
                g_loss_cls = self.classification_loss(-out_cls + 1.0, hidden_categorys_onehot)
                # backward
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()
                # smoothing
                self.update_average(self.G_test, self.G, self.smooth_beta)
                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Translate fixed images for debugging.
            if (i + 1) % self.sample_step == 0:
                x_fake_list = [x_real]
                x_fake_list.append(x_fake)
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i + 1))
                G_test_path = os.path.join(self.model_save_dir, '{}-G_test.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.G_test.state_dict(), G_test_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self, test_iters=None):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        if test_iters is not None:
            self.restore_model(test_iters)
        # self.eval_model()

        # Set data loader.
        data_loader = self.loader
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):
                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = []
                for j in range(self.c_dim):
                    c_trg = c_org.clone()
                    c_trg[:, :] = 0.0
                    c_trg[:, j] = 1.0
                    c_trg_list.append(c_trg.to(self.device))

                # Translate images.
                x_fake_list = []
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G_test(x_real, c_trg))

                # Save the translated images; compute the path first so the
                # error message below can always reference it.
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i + 1))
                try:
                    x_concat = torch.cat(x_fake_list, dim=3)
                    save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(result_path))
                except Exception:
                    import traceback
                    traceback.print_exc()
                    print('Error {}...'.format(result_path))
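# --- Hedged sketch (assumed helper) ------------------------------------------
# Solver.train and Solver.update_average rely on toogle_grad, which is not
# defined in this excerpt. A minimal version consistent with how it is called,
# toogle_grad(model, flag), keeping the (misspelled) name used at the call
# sites, would be:
def toogle_grad(model, requires_grad):
    # enable/disable gradient tracking for every parameter of the model
    for p in model.parameters():
        p.requires_grad_(requires_grad)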
        loss_D_h2l_log.append(loss_D_h2l.item())
        loss_D_l2h_log.append(loss_D_l2h.item())
        loss_G_h2l_log.append(loss_G_h2l.item())
        loss_G_l2h_log.append(loss_G_l2h.item())
        loss_cycle_log.append(loss_cycle.item())
        loss_all = (loss_D_h2l_log, loss_D_l2h_log, loss_G_h2l_log,
                    loss_G_l2h_log, loss_cycle_log)
        file = open("loss_log.txt", "w")
        json.dump(loss_all, file)
        file.close()

    print("\n Testing and saving...")
    G_h2l.eval()
    D_h2l.eval()
    G_l2h.eval()
    D_l2h.eval()
    if ep % 10 == 0:
        for i, sample in enumerate(test_loader):
            if i >= num_test:
                break
            low_temp = sample["img16"].numpy()
            low = torch.from_numpy(
                np.ascontiguousarray(low_temp[:, ::-1, :, :])).cuda()
            with torch.no_grad():
                high_gen = G_l2h(low)  # fixed 'hign_gen' typo
            np_low = low.cpu().numpy().transpose(0, 2, 3, 1).squeeze(0)
            np_gen = high_gen.detach().cpu().numpy().transpose(0, 2, 3, 1).squeeze(0)
real_loss = bce(Dreal, y_real) Dfake = D(fake) y_fake = torch.zeros_like(Dfake).to(device) fake_loss = bce(Dfake, y_fake) last_gen_loss = torch.mean(Dfake) total_loss = real_loss + fake_loss total_loss.backward() D_optim.step() D_losses.append(total_loss.item()) print("Epoch %s, G_loss: %f, D_loss: %f, time: %.3f, lr: %.3f" % (epoch, Gloss, total_loss, time.time() - train_t, lr)) if epoch % 10 == 0: with torch.no_grad(): D.eval() G.eval() plt.figure() z_ = sample_z(batch_size, z_dim).to(device).view(-1, z_dim, 1, 1) fake = G(z_).squeeze() fake = fake.cpu().numpy() fake = fake[0, :, :] fake[fake < 0] = 0 plt.imshow(fake) plt.colorbar() plt.savefig('./plots/' + str(epoch) + ".png") plt.close() if (epoch + 1) % 50 == 0: torch.save(G, save_path + 'G_epoch' + str(epoch)) torch.save(D, save_path + 'D_epoch' + str(epoch)) print('Saved Model')
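# --- Hedged sketch (assumed helper) ------------------------------------------
# The snippet above draws latents with sample_z, defined elsewhere. For a
# DCGAN-style generator a minimal stand-in is an isotropic Gaussian; the caller
# reshapes the result to (N, z_dim, 1, 1):
import torch

def sample_z(batch_size, z_dim):
    return torch.randn(batch_size, z_dim)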
    def train(self, src_data, tgt_data):
        params = self.params
        print(params)
        penalty = 10.0  # penalty on cosine similarity
        print('Subword penalty {}'.format(penalty))

        # Load data
        if not os.path.exists(params.data_dir):
            # raising a string is invalid in Python 3; use an exception object
            raise IOError("Data path doesn't exist: %s" % params.data_dir)

        src_lang = params.src_lang
        tgt_lang = params.tgt_lang
        self.suffix_str = src_lang + '_' + tgt_lang

        evaluator = Evaluator(params, src_data=src_data, tgt_data=tgt_data)
        monitor = Monitor(params, src_data=src_data, tgt_data=tgt_data)

        # Initialize subword embedding transformer
        # print('Initializing subword embedding transformer...')
        # src_data['F'].eval()
        # src_optimizer = optim.SGD(src_data['F'].parameters())
        # for _ in trange(128):
        #     indices = np.random.permutation(src_data['seqs'].size(0))
        #     indices = torch.LongTensor(indices)
        #     if torch.cuda.is_available():
        #         indices = indices.cuda()
        #     total_loss = 0
        #     for batch in indices.split(params.mini_batch_size):
        #         src_optimizer.zero_grad()
        #         vecs0 = src_data['vecs'][batch]  # original
        #         vecs = src_data['F'](src_data['seqs'][batch], src_data['E'])
        #         loss = F.mse_loss(vecs0, vecs)
        #         loss.backward()
        #         total_loss += float(loss)
        #         src_optimizer.step()
        # print('Done: final loss = {:.2f}'.format(total_loss))
        src_optimizer = optim.SGD(src_data['F'].parameters(),
                                  lr=params.sw_learning_rate, momentum=0.9)
        print('Src optim: {}'.format(src_optimizer))

        # Loss function
        loss_fn = torch.nn.BCELoss()

        # Create models
        g = Generator(input_size=params.g_input_size,
                      hidden_size=params.g_hidden_size,
                      output_size=params.g_output_size)
        if self.params.model_file:
            print('Load a model from ' + self.params.model_file)
            g.load(self.params.model_file)
        d = Discriminator(input_size=params.d_input_size,
                          hidden_size=params.d_hidden_size,
                          output_size=params.d_output_size,
                          hyperparams=get_hyperparams(params, disc=True))

        seed = params.seed
        self.initialize_exp(seed)

        if not params.disable_cuda and torch.cuda.is_available():
            print('Use GPU')
            # Move the network and the optimizer to the GPU
            g.cuda()
            d.cuda()
            loss_fn = loss_fn.cuda()

        if self.params.model_file is None:
            print('Initializing G based on distribution')
            # if the relative change of loss values is smaller than tol, stop iteration
            topn = 10000
            tol = 1e-5
            prev_loss, loss = None, None
            g_optimizer = optim.SGD(g.parameters(), lr=0.01, momentum=0.9)
            batches = src_data['seqs'][:topn].split(params.mini_batch_size)
            src_emb = torch.cat([
                src_data['F'](batch, src_data['E']).detach() for batch in batches
            ])
            tgt_emb = tgt_data['E'].emb.weight[:topn]
            if not params.disable_cuda and torch.cuda.is_available():
                src_emb = src_emb.cuda()
                tgt_emb = tgt_emb.cuda()
            src_emb = F.normalize(src_emb)
            tgt_emb = F.normalize(tgt_emb)
            src_mean = src_emb.mean(dim=0).detach()
            tgt_mean = tgt_emb.mean(dim=0).detach()
            # src_std = src_emb.std(dim=0).detach()
            # tgt_std = tgt_emb.std(dim=0).detach()
            for _ in trange(1000):  # at most 1000 iterations
                prev_loss = loss
                g_optimizer.zero_grad()
                mapped_src_mean = g(src_mean)
                loss = F.mse_loss(mapped_src_mean, tgt_mean)
                loss.backward()
                g_optimizer.step()
                # Orthogonalize
                self.orthogonalize(g.map1.weight.data)
                loss = float(loss)
                if type(prev_loss) is float and abs(prev_loss - loss) / prev_loss <= tol:
                    break
            print('Done: final loss = {}'.format(float(loss)))

        evaluator.precision(g, src_data, tgt_data)
        sim = monitor.cosine_similarity(g, src_data, tgt_data)
        print('Cos sim.: {:3f} (+/-{:.3})'.format(sim.mean(), sim.std()))

        d_acc_epochs, g_loss_epochs = [], []

        # Define optimizers
        d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
        g_optimizer =
optim.SGD(g.parameters(), lr=params.g_learning_rate) for epoch in range(params.num_epochs): d_losses, g_losses = [], [] hit = 0 total = 0 start_time = timer() for mini_batch in range( 0, params.iters_in_epoch // params.mini_batch_size): for d_index in range(params.d_steps): d_optimizer.zero_grad() # Reset the gradients d.train() X, y, _ = self.get_batch_data(src_data, tgt_data, g) pred = d(X) d_loss = loss_fn(pred, y) d_loss.backward() d_optimizer.step() d_losses.append(d_loss.data.cpu().numpy()) discriminator_decision = pred.data.cpu().numpy() hit += np.sum( discriminator_decision[:params.mini_batch_size] >= 0.5) hit += np.sum( discriminator_decision[params.mini_batch_size:] < 0.5) sys.stdout.write("[%d/%d] :: Discriminator Loss: %f \r" % (mini_batch, params.iters_in_epoch // params.mini_batch_size, np.asscalar(np.mean(d_losses)))) sys.stdout.flush() total += 2 * params.mini_batch_size * params.d_steps for g_index in range(params.g_steps): # 2. Train G on D's response (but DO NOT train D on these labels) g_optimizer.zero_grad() src_optimizer.zero_grad() d.eval() X, y, src_vecs = self.get_batch_data(src_data, tgt_data, g) pred = d(X) g_loss = loss_fn(pred, 1 - y) src_loss = F.mse_loss(*src_vecs) if g_loss.is_cuda: src_loss = src_loss.cuda() loss = g_loss + penalty * src_loss loss.backward() g_optimizer.step() # Only optimizes G's parameters src_optimizer.step() g_losses.append(g_loss.data.cpu().numpy()) # Orthogonalize self.orthogonalize(g.map1.weight.data) sys.stdout.write( "[%d/%d] :: Generator Loss: %f \r" % (mini_batch, params.iters_in_epoch // params.mini_batch_size, np.asscalar(np.mean(g_losses)))) sys.stdout.flush() d_acc_epochs.append(hit / total) g_loss_epochs.append(np.asscalar(np.mean(g_losses))) print( "Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f}, Generator Loss: {:.5f}, Time elapsed {:.2f} mins" .format(epoch, np.asscalar(np.mean(d_losses)), hit / total, np.asscalar(np.mean(g_losses)), (timer() - start_time) / 60)) filename = path.join(params.model_dir, 'g_e{}.pth'.format(epoch)) print('Save a generator to ' + filename) g.save(filename) filename = path.join(params.model_dir, 's_e{}.pth'.format(epoch)) print('Save a subword transformer to ' + filename) src_data['F'].save(filename) if (epoch + 1) % params.print_every == 0: evaluator.precision(g, src_data, tgt_data) sim = monitor.cosine_similarity(g, src_data, tgt_data) print('Cos sim.: {:3f} (+/-{:.3})'.format( sim.mean(), sim.std())) return g
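# --- Hedged sketch (assumed helper) ------------------------------------------
# self.orthogonalize(g.map1.weight.data) keeps the learned mapping close to an
# orthogonal matrix. The usual update from adversarial bilingual-embedding work
# such as MUSE (Conneau et al., 2018) is the in-place iteration below; the beta
# value is an assumption, the one actually used in this repo is not shown:
def orthogonalize(W, beta=0.01):
    # W <- (1 + beta) * W - beta * (W W^T) W
    W.copy_((1 + beta) * W - beta * W.mm(W.t().mm(W)))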
def training(opt):
    # ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
    EPOCHS = opt.epochs
    CHANNELS = 1
    H, W = 64, 64
    lr = opt.lr
    work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    FEATURE_D = 128
    Z_DIM = 100
    GEN_TRAIN_STEPS = 5
    BATCH_SIZE = opt.batch_size

    if opt.logs:
        log_dir = Path(f'{opt.logs}').resolve()
        if log_dir.exists():
            shutil.rmtree(str(log_dir))
    if opt.weights:
        Weight_dir = Path(f'{opt.weights}').resolve()
        if not Weight_dir.exists():
            Weight_dir.mkdir()

    # ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #
    trans = transforms.Compose([
        transforms.Resize((H, W)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    MNIST_data = MNIST('./data', True, transform=trans, download=True)
    loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)

    # ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #
    writer_fake = SummaryWriter(f"{str(log_dir)}/fake")
    writer_real = SummaryWriter(f"{str(log_dir)}/real")
    loss_writer = SummaryWriter(f"{str(log_dir)}/loss")

    # ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #
    disc = Discriminator(img_channels=CHANNELS, feature_d=FEATURE_D).to(work_device)
    gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)

    if opt.resume:
        # 'discriminator.pth' was misspelled 'dirscriminator.pth'
        if Path(Weight_dir / 'discriminator.pth').exists():
            disc.load_state_dict(
                torch.load(str(Weight_dir / 'discriminator.pth'),
                           map_location=work_device))
        if Path(Weight_dir / 'generator.pth').exists():
            gen.load_state_dict(
                torch.load(str(Weight_dir / 'generator.pth'),
                           map_location=work_device))

    # ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #
    disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
    gen_optim = optim.Adam(gen.parameters(), lr, (0.5, 0.999))
    criterion = torch.nn.BCELoss()

    # ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
    D_loss_prev = math.inf
    G_loss_prev = math.inf
    D_loss = 0
    G_loss = 0

    for epoch in range(EPOCHS):
        for batch_idx, (real, _) in enumerate(tqdm(loader)):
            disc.train()
            gen.train()
            real = real.to(work_device)
            # note: re-sampled every batch despite the name
            fixed_noise = torch.rand(real.shape[0], Z_DIM, 1, 1).to(work_device)

            # ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #
            fake = gen(fixed_noise)  # dim of (N,1,64,64)

            # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
            real_predict = disc(real).view(-1)  # make it one dimensional array
            fake_predict = disc(fake).view(-1)  # make it one dimensional array
            labels = torch.cat([
                torch.ones_like(real_predict),
                torch.zeros_like(fake_predict)
            ], dim=0)

            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
            D_loss = criterion(torch.cat([real_predict, fake_predict], dim=0), labels)

            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
            disc.zero_grad()
            D_loss.backward()
            disc_optim.step()

            # ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
            for _ in range(GEN_TRAIN_STEPS):
                # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
                fake = gen(fixed_noise)
                fake_predict = disc(fake).view(-1)  # one dimensional array

                # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
                G_loss = criterion(fake_predict, torch.ones_like(fake_predict))

                # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
                gen.zero_grad()
                G_loss.backward()
                gen_optim.step()

            # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #
            if batch_idx == 0:
                print(f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} "
                      f"Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}")
                with torch.no_grad():
                    disc.eval()
                    gen.eval()
                    fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
                    data = real.reshape(-1, CHANNELS, H, W)
                    if BATCH_SIZE > 32:
                        fake = fake[:32]
                        data = data[:32]
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                    writer_fake.add_image("Mnist Fake Images", img_grid_fake,
                                          global_step=epoch)
                    writer_real.add_image("Mnist Real Images", img_grid_real,
                                          global_step=epoch)
                    loss_writer.add_scalar('discriminator', D_loss, global_step=epoch)
                    loss_writer.add_scalar('generator', G_loss, global_step=epoch)

        # ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
        if opt.weights:
            if D_loss_prev > D_loss:
                D_loss_prev = D_loss
                weight_path = str(Weight_dir / 'discriminator.pth')
                torch.save(disc.state_dict(), weight_path)
            if G_loss_prev > G_loss:
                G_loss_prev = G_loss
                weight_path = str(Weight_dir / 'generator.pth')
                torch.save(gen.state_dict(), weight_path)
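# --- Design note with sketch (not in the original) ---------------------------
# In training() above, the tensor named fixed_noise is re-sampled every batch,
# so the TensorBoard grids track different latents each epoch. To monitor the
# generator on genuinely fixed latents, sample a batch once before the epoch
# loop and reuse it in the logging branch (uniform noise, matching the script's
# torch.rand sampling):
import torch

def make_viz_noise(z_dim, n=32, device='cpu'):
    # sample once, e.g. viz_noise = make_viz_noise(Z_DIM, device=work_device),
    # then log gen(viz_noise) every epoch
    return torch.rand(n, z_dim, 1, 1, device=device)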
    def train(self, src_emb, tgt_emb):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            # raising a string is invalid in Python 3; use an exception object
            raise IOError("Data path doesn't exist: %s" % params.data_dir)

        en = src_emb
        it = tgt_emb

        self.params = _get_eval_params(params)
        params = self.params

        for _ in range(params.num_random_seeds):
            # Create models
            g = Generator(input_size=params.g_input_size,
                          output_size=params.g_output_size)
            d = Discriminator(input_size=params.d_input_size,
                              hidden_size=params.d_hidden_size,
                              output_size=params.d_output_size)
            print(d)
            lowest_loss = 1e5
            g.apply(self.weights_init3)
            seed = random.randint(0, 1000)
            self.initialize_exp(seed)

            loss_fn = torch.nn.BCELoss()
            loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
            # d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
            # g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)
            # d_optimizer = optim.Adam(d.parameters(), lr=params.d_learning_rate)
            d_optimizer = optim.RMSprop(d.parameters(), lr=params.d_learning_rate)
            g_optimizer = optim.Adam(g.parameters(), lr=params.g_learning_rate)

            if torch.cuda.is_available():
                # Move the network and the optimizer to the GPU
                g = g.cuda()
                d = d.cuda()
                loss_fn = loss_fn.cuda()
                loss_fn2 = loss_fn2.cuda()

            # true_dict = get_true_dict(params.data_dir)
            d_acc_epochs = []
            g_loss_epochs = []
            d_loss_epochs = []
            acc_all = []
            d_losses = []
            g_losses = []
            csls_epochs = []
            recon_losses = []
            w_losses = []
            try:
                for epoch in range(params.num_epochs):
                    recon_losses = []
                    w_losses = []
                    start_time = timer()
                    for mini_batch in range(0, params.iters_in_epoch // params.mini_batch_size):
                        hit, total = 0, 0
                        for d_index in range(params.d_steps):
                            d_optimizer.zero_grad()  # Reset the gradients
                            d.train()
                            # input, output = self.get_batch_data_fast(en, it, g, detach=True)
                            src_batch, tgt_batch = self.get_batch_data_fast_new(en, it)
                            fake, _ = g(src_batch)
                            fake = fake.detach()
                            real = tgt_batch
                            # input = torch.cat([fake, real], 0)
                            input = torch.cat([real, fake], 0)
                            output = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_())
                            output[:params.mini_batch_size] = 1 - params.smoothing
                            output[params.mini_batch_size:] = params.smoothing
                            pred = d(input)
                            d_loss = loss_fn(pred, output)
                            d_loss.backward()  # compute/store gradients, but don't change params
                            d_losses.append(d_loss.data.cpu().numpy())
                            discriminator_decision = pred.data.cpu().numpy()
                            hit += np.sum(discriminator_decision[:params.mini_batch_size] >= 0.5)
                            hit += np.sum(discriminator_decision[params.mini_batch_size:] < 0.5)
                            d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()
                            # Clip weights
                            _clip(d, params.clip_value)
                            # float() replaces the removed np.asscalar
                            sys.stdout.write("[%d/%d] :: Discriminator Loss: %f \r" % (
                                mini_batch,
                                params.iters_in_epoch // params.mini_batch_size,
                                float(np.mean(d_losses[-1000:]))))
                            sys.stdout.flush()
                        total += 2 * params.mini_batch_size * params.d_steps
                        for g_index in range(params.g_steps):  # 2.
Train G on D's response (but DO NOT train D on these labels) g_optimizer.zero_grad() d.eval() src_batch, tgt_batch = self.get_batch_data_fast_new(en, it) fake, recon = g(src_batch) real = tgt_batch output = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_()) output[:params.mini_batch_size] = 1 - params.smoothing output[params.mini_batch_size:] = params.smoothing pred = d(fake) output2 = to_variable(torch.FloatTensor(params.mini_batch_size).zero_()) output2 = output2+1-params.smoothing recon_loss = 1.0 - torch.mean(loss_fn2(src_batch,recon)) g_loss = loss_fn(pred, output2) + params.recon_weight * recon_loss g_loss.backward() g_losses.append(g_loss.data.cpu().numpy()) recon_losses.append(recon_loss.data.cpu().numpy()) g_optimizer.step() # Only optimizes G's parameters #self.orthogonalize(g.map1.weight.data) sys.stdout.write("[%d/%d] :: Generator Loss: %f \r" % ( mini_batch, params.iters_in_epoch // params.mini_batch_size, np.asscalar(np.mean(g_losses[-1000:])))) sys.stdout.flush() acc_all.append(hit / total) if epoch > params.threshold: if lowest_loss > float(g_loss.data): lowest_loss = float(g_loss.data) W = g.map1.weight.data.cpu().numpy() w_losses.append(np.linalg.norm(np.dot(W.T, W) - np.identity(params.g_input_size))) X_Z = g(src_emb.weight)[0].data Y_Z = tgt_emb.weight.data mstart_time = timer() for method in [params.dico_method]: results = get_word_translation_accuracy( 'en', self.src_ids, X_Z, 'zh', self.tgt_ids, Y_Z, method=method, path = params.data_dir+params.validation_file ) acc = results[0][1] #print('{} takes {:.2f}s'.format(method, timer() - mstart_time)) #print('Method:{} score:{:.4f}'.format(method,acc)) torch.save(g.state_dict(), 'tune/best/G_seed{}_epoch_{}_batch_{}_mf_{}_p@1_{:.3f}.t7'.format(seed,epoch,mini_batch,params.most_frequent_sampling_size,acc)) ''' if mini_batch % 500==0: #d_acc_epochs.append(hit / total) #d_loss_epochs.append(np.asscalar(np.mean(d_losses))) #g_loss_epochs.append(np.asscalar(np.mean(g_losses))) if epoch > params.threshold: W = g.map1.weight.data.cpu().numpy() w_loss = np.linalg.norm(np.dot(W.T, W) - np.identity(params.g_input_size)) #print("D_acc:{:.3f} d_loss:{:.3f} g_loss:{:.3f} w_loss:{:.2f} ".format(hit / total,np.asscalar(np.mean(d_losses)),np.asscalar(np.mean(g_losses)),w_loss)) #print("D_acc:{:.3f} d_loss:{:.3f} g_loss:{:.3f} w_loss:{:.2f}".format(hit / total,d_loss.data[0],g_loss.data[0],w_loss)) X_Z = g(src_emb.weight)[0].data Y_Z = tgt_emb.weight.data mstart_time = timer() for method in [params.dico_method]: results = get_word_translation_accuracy( 'en', self.src_ids, X_Z, 'zh', self.tgt_ids, Y_Z, method=method, path = params.validation_file ) acc = results[0][1] #print('epoch:{} Method:{} score:{:.4f}'.format(mini_batch,method,acc)) torch.save(g.state_dict(), 'tune/thu/G_seed{}_epoch_{}_mf_{}_p@1_{:.3f}.t7'.format(seed,mini_batch,params.most_frequent_sampling_size,acc)) ''' X_Z = g(src_emb.weight)[0].data Y_Z = tgt_emb.weight.data print("Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f},Generator Loss: {:.5f}, Time elapsed {:.2f}mins".format(epoch,np.asscalar(np.mean(d_losses[-1562:])), hit / total,np.asscalar(np.mean(g_losses[-1562:])),(timer() - start_time) / 60)) mstart_time = timer() for method in [params.dico_method]: results = get_word_translation_accuracy( 'en', self.src_ids, X_Z, 'zh', self.tgt_ids, Y_Z, method=method, path = params.data_dir+params.validation_file ) acc = results[0][1] print('epoch:{} Method:{} score:{:.4f}'.format(epoch,method,acc)) torch.save(g.state_dict(), 
'tune/G_seed{}_epoch_{}_mf_{}_p@1_{:.3f}.t7'.format(seed,epoch,params.most_frequent_sampling_size,acc)) # Save the plot for discriminator accuracy and generator loss fig = plt.figure() plt.plot(range(0, len(acc_all)), acc_all, color='b', label='discriminator') plt.ylabel('D_accuracy_all') plt.xlabel('epochs') plt.legend() fig.savefig('tune/D_acc_all.png') fig = plt.figure() plt.plot(range(0, len(d_losses)), d_losses, color='b', label='discriminator') plt.ylabel('D_loss_all') plt.xlabel('epochs') plt.legend() fig.savefig('tune/D_loss_all.png') fig = plt.figure() plt.plot(range(0, len(g_losses)), g_losses, color='b', label='discriminator') plt.ylabel('G_loss_all') plt.xlabel('epochs') plt.legend() fig.savefig('tune/G_loss_all.png') fig = plt.figure() plt.plot(range(0, len(w_losses)), w_losses, color='b', label='discriminator') plt.ylabel('||W^T*W - I||') plt.xlabel('epochs') plt.legend() fig.savefig('tune/W^TW.png') plt.close('all') except KeyboardInterrupt: print("Interrupted.. saving model !!!") torch.save(g.state_dict(), 'tune/g_model_interrupt.t7') torch.save(d.state_dict(), 'tune/d_model_interrupt.t7') exit() return g
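# The loop above relies on two small helpers that are not defined in this file:
# to_variable() and _clip(). The sketch below is one plausible reading, assuming
# to_variable() just moves a tensor to the GPU when available (the older
# Variable idiom used throughout this file) and _clip() does WGAN-style
# elementwise weight clipping; the project's own versions may differ.
import torch
from torch.autograd import Variable


def to_variable(tensor, requires_grad=False):
    # Move to GPU if available, then wrap in a Variable.
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return Variable(tensor, requires_grad=requires_grad)


def _clip(model, clip_value):
    # Clamp every parameter of `model` into [-clip_value, clip_value],
    # as done after each discriminator step above.
    if clip_value > 0:
        for p in model.parameters():
            p.data.clamp_(-clip_value, clip_value)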
class FNM(object):
    def __init__(self, args):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

        self.batch_size = args.batch_size
        self.lr = args.lr
        self.profile_list_path = args.profile_list
        self.front_list_path = args.front_list
        self.profile_path = args.profile_path
        self.front_path = args.front_path
        self.test_path = args.test_path
        self.test_list = args.test_list
        self.crop_size = args.ori_height
        self.image_size = args.height
        self.res_n = args.res_n
        self.is_finetune = args.is_finetune
        self.result_name = args.result_name
        self.summary_dir = args.summary_dir
        self.iteration = args.iteration
        self.weight_decay = args.weight_decay
        self.decay_flag = args.decay_flag
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.width
        self.model_name = args.model_name

        # Hyper-parameters (loss weights)
        self.lambda_l1 = args.lambda_l1
        self.lambda_fea = args.lambda_fea
        self.lambda_reg = args.lambda_reg
        self.lambda_gan = args.lambda_gan
        self.lambda_gp = args.lambda_gp
        self.channel = args.channel
        self.device = torch.device("cuda:{}".format(args.device_id))

        self.make_dirs()
        self.build_model()

        """Define losses"""
        self.L1_loss = nn.L1Loss().to(self.device)
        self.L2_loss = nn.MSELoss().to(self.device)

    def make_dirs(self):
        check_folder(self.summary_dir)
        check_folder(os.path.join("results", self.result_name, "model"))
        check_folder(os.path.join("results", self.result_name, "img"))

    def build_model(self):
        # Frozen expert recognition network (ArcFace SE-50) used as a feature extractor
        self.expert_net = se50_net("./other_models/arcface_se50/model_ir_se50.pth").to(self.device)
        for param in self.expert_net.parameters():
            param.requires_grad = False

        # self.dataset = sample_dataset(self.profile_list_path, self.front_list_path,
        #                               self.profile_path, self.front_path,
        #                               self.crop_size, self.image_size)
        self.front_loader = get_loader(self.front_list_path, self.front_path,
                                       self.crop_size, self.image_size,
                                       self.batch_size, mode="train", num_workers=8)
        self.profile_loader = get_loader(self.profile_list_path, self.profile_path,
                                         self.crop_size, self.image_size,
                                         self.batch_size, mode="train", num_workers=8)
        self.test_loader = get_loader(self.test_list, self.test_path,
                                      self.crop_size, self.image_size,
                                      self.batch_size, mode="test", num_workers=8)
        # self.front_loader = iter(self.front_loader)
        # self.profile_loader = iter(self.profile_loader)

        # Resnet blocks operating on the expert feature map [b, 512, 7, 7]
        resnet_block_list = []
        for i in range(self.res_n):
            resnet_block_list.append(ResnetBlock(512, use_bias=False))
        self.body = nn.Sequential(*resnet_block_list).to(self.device)

        self.decoder = Decoder().to(self.device)
        self.dis = Discriminator(self.channel).to(self.device)

        self.G_optim = torch.optim.Adam(
            itertools.chain(self.body.parameters(), self.decoder.parameters()),
            lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        self.D_optim = torch.optim.Adam(
            self.dis.parameters(),
            lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        self.downsample112x112 = nn.Upsample(size=(112, 112), mode='bilinear')

    def update_lr(self, cur_iter):
        # Linear decay over the second half of training
        if self.decay_flag and cur_iter > (self.iteration // 2):
            self.G_optim.param_groups[0]['lr'] -= (
                self.lr / (self.iteration // 2)) * (cur_iter - self.iteration // 2)
            self.D_optim.param_groups[0]['lr'] -= (
                self.lr / (self.iteration // 2)) * (cur_iter - self.iteration // 2)

    def train(self):
        self.body.train(), self.decoder.train(), self.dis.train()

        start_iter = 1
        if self.is_finetune:
            model_list = glob(os.path.join("results", self.result_name, "model", "*.pt"))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join("results", self.result_name, 'model'), start_iter)
                print(" [*] Load SUCCESS")
                self.update_lr(start_iter)

        print("training start...")
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            self.update_lr(step)

            try:
                front_224, front_112 = next(front_iter)
                if front_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (NameError, StopIteration):
                # (Re)start the loader when it is missing, exhausted,
                # or returned an incomplete batch.
                front_iter = iter(self.front_loader)
                front_224, front_112 = next(front_iter)
            try:
                profile_224, profile_112 = next(profile_iter)
                if profile_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (NameError, StopIteration):
                profile_iter = iter(self.profile_loader)
                profile_224, profile_112 = next(profile_iter)

            profile_224, front_224, profile_112, front_112 = (
                profile_224.to(self.device), front_224.to(self.device),
                profile_112.to(self.device), front_112.to(self.device))

            # ---------------- Update D ----------------
            self.D_optim.zero_grad()
            feature_p = self.expert_net.get_feature(profile_112)
            feature_f = self.expert_net.get_feature(front_112)
            gen_p = self.decoder(self.body(feature_p))
            gen_f = self.decoder(self.body(feature_f))
            feature_gen_p = self.expert_net.get_feature(self.downsample112x112(gen_p))
            feature_gen_f = self.expert_net.get_feature(self.downsample112x112(gen_f))

            d_f = self.dis(front_224)
            d_gen_p = self.dis(gen_p)
            d_gen_f = self.dis(gen_f)

            # Critic loss averaged over the five discriminator heads
            D_adv_loss = torch.mean(
                tensor_tuple_sum(d_gen_f) * 0.5 + tensor_tuple_sum(d_gen_p) * 0.5
                - tensor_tuple_sum(d_f)) / 5

            # Gradient penalty on interpolates between real fronts and generated profiles
            alpha = torch.rand(gen_p.size(0), 1, 1, 1).to(self.device)
            inter = (alpha * front_224.data + (1 - alpha) * gen_p.data).requires_grad_(True)
            out_inter = self.dis(inter)
            gradient_penalty_loss = (
                gradient_penalty(out_inter[0], inter, self.device) +
                gradient_penalty(out_inter[1], inter, self.device) +
                gradient_penalty(out_inter[2], inter, self.device) +
                gradient_penalty(out_inter[3], inter, self.device) +
                gradient_penalty(out_inter[4], inter, self.device)) / 5
            # print("gradient_penalty_loss:{}".format(gradient_penalty_loss))

            d_loss = self.lambda_gan * D_adv_loss + self.lambda_gp * gradient_penalty_loss
            d_loss.backward(retain_graph=True)
            self.D_optim.step()

            # ---------------- Update G ----------------
            self.G_optim.zero_grad()
            try:
                front_224, front_112 = next(front_iter)
                if front_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (NameError, StopIteration):
                front_iter = iter(self.front_loader)
                front_224, front_112 = next(front_iter)
            try:
                profile_224, profile_112 = next(profile_iter)
                if profile_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (NameError, StopIteration):
                profile_iter = iter(self.profile_loader)
                profile_224, profile_112 = next(profile_iter)

            profile_224, front_224, profile_112, front_112 = (
                profile_224.to(self.device), front_224.to(self.device),
                profile_112.to(self.device), front_112.to(self.device))

            feature_p = self.expert_net.get_feature(profile_112)
            feature_f = self.expert_net.get_feature(front_112)
            gen_p = self.decoder(self.body(feature_p))
            gen_f = self.decoder(self.body(feature_f))
            feature_gen_p = self.expert_net.get_feature(self.downsample112x112(gen_p))
            feature_gen_f = self.expert_net.get_feature(self.downsample112x112(gen_f))

            d_f = self.dis(front_224)
            d_gen_p = self.dis(gen_p)
            d_gen_f = self.dis(gen_f)

            # Frontal images should reconstruct themselves
            pixel_loss = torch.mean(self.L1_loss(front_224, gen_f))

            # Identity-preserving perceptual loss on normalized expert features
            feature_p_norm = l2_norm(feature_p)
            feature_f_norm = l2_norm(feature_f)
            feature_gen_p_norm = l2_norm(feature_gen_p)
            feature_gen_f_norm = l2_norm(feature_gen_f)
            perceptual_loss = torch.mean(
                0.5 * (1 - torch.sum(torch.mul(feature_p_norm, feature_gen_p_norm), dim=(1, 2, 3))) +
                0.5 * (1 - torch.sum(torch.mul(feature_f_norm, feature_gen_f_norm), dim=(1, 2, 3))))

            G_adv_loss = -torch.mean(
                tensor_tuple_sum(d_gen_f) * 0.5 + tensor_tuple_sum(d_gen_p) * 0.5) / 5
            g_loss = (self.lambda_gan * G_adv_loss + self.lambda_l1 * pixel_loss +
                      self.lambda_fea * perceptual_loss)
            g_loss.backward()
            self.G_optim.step()

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time, d_loss, g_loss))
            print("D_adv_loss : %.8f" % (self.lambda_gan * D_adv_loss))
            print("G_adv_loss : %.8f" % (self.lambda_gan * G_adv_loss))
            print("pixel_loss : %.8f" % (self.lambda_l1 * pixel_loss))
            print("perceptual_loss : %.8f" % (self.lambda_fea * perceptual_loss))
            print("gp_loss : %.8f" % (self.lambda_gp * gradient_penalty_loss))

            with torch.no_grad():
                if step % self.print_freq == 0:
                    train_sample_num = 5
                    test_sample_num = 5
                    A2B = np.zeros((self.img_size * 4, 0, 3))
                    self.body.eval(), self.decoder.eval(), self.dis.eval()

                    for _ in range(train_sample_num):
                        try:
                            front_224, front_112 = next(front_iter)
                            if front_224.shape[0] != self.batch_size:
                                raise StopIteration
                        except (NameError, StopIteration):
                            front_iter = iter(self.front_loader)
                            front_224, front_112 = next(front_iter)
                        try:
                            profile_224, profile_112 = next(profile_iter)
                            if profile_224.shape[0] != self.batch_size:
                                raise StopIteration
                        except (NameError, StopIteration):
                            profile_iter = iter(self.profile_loader)
                            profile_224, profile_112 = next(profile_iter)

                        profile_224, front_224, profile_112, front_112 = (
                            profile_224.to(self.device), front_224.to(self.device),
                            profile_112.to(self.device), front_112.to(self.device))

                        feature_p = self.expert_net.get_feature(profile_112)
                        feature_f = self.expert_net.get_feature(front_112)
                        gen_p = self.decoder(self.body(feature_p))
                        gen_f = self.decoder(self.body(feature_f))
                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(profile_224[0]))),
                                  RGB2BGR(tensor2numpy(denorm(gen_p[0]))),
                                  RGB2BGR(tensor2numpy(denorm(front_224[0]))),
                                  RGB2BGR(tensor2numpy(denorm(gen_f[0])))), 0)), 1)

                    for _ in range(test_sample_num):
                        show_list = []
                        for i in range(2):
                            try:
                                test_profile_224, test_profile_112 = next(test_iter)
                                if test_profile_224.shape[0] != self.batch_size:
                                    raise StopIteration
                            except (NameError, StopIteration):
                                test_iter = iter(self.test_loader)
                                test_profile_224, test_profile_112 = next(test_iter)
                            test_profile_224, test_profile_112 = (
                                test_profile_224.to(self.device),
                                test_profile_112.to(self.device))
                            test_feature_p = self.expert_net.get_feature(test_profile_112)
                            test_gen_p = self.decoder(self.body(test_feature_p))
                            show_list.append(test_profile_224[0])
                            show_list.append(test_gen_p[0])
                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(show_list[0]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[1]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[2]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[3])))), 0)), 1)

                    cv2.imwrite(os.path.join("results", self.result_name, 'img',
                                             'A2B_%07d.png' % step), A2B * 255.0)
                    self.body.train(), self.decoder.train(), self.dis.train()

                if step % self.save_freq == 0:
                    self.save(os.path.join("results", self.result_name, "model"), step)

                if step % 1000 == 0:
                    params = {}
                    params['body'] = self.body.state_dict()
                    params['decoder'] = self.decoder.state_dict()
                    params['dis'] = self.dis.state_dict()
                    torch.save(params, os.path.join("results", self.result_name,
                                                    self.model_name + "_params_latest.pt"))

    def load(self, dir, step):
        params = torch.load(os.path.join(dir, self.model_name + '_params_%07d.pt' % step))
        self.body.load_state_dict(params['body'])
        self.decoder.load_state_dict(params['decoder'])
        self.dis.load_state_dict(params['dis'])

    def save(self, dir, step):
        params = {}
        params['body'] = self.body.state_dict()
        params['decoder'] = self.decoder.state_dict()
        params['dis'] = self.dis.state_dict()
        torch.save(params, os.path.join(dir, self.model_name + '_params_%07d.pt' % step))

    def demo(self):
        try:
            front_224, front_112 = next(front_iter)
            if front_224.shape[0] != self.batch_size:
                raise StopIteration
        except (NameError, StopIteration):
            front_iter = iter(self.front_loader)
            front_224, front_112 = next(front_iter)
        try:
            profile_224, profile_112 = next(profile_iter)
            if profile_224.shape[0] != self.batch_size:
                raise StopIteration
        except (NameError, StopIteration):
            profile_iter = iter(self.profile_loader)
            profile_224, profile_112 = next(profile_iter)

        profile_224, front_224, profile_112, front_112 = (
            profile_224.to(self.device), front_224.to(self.device),
            profile_112.to(self.device), front_112.to(self.device))

        D_face, D_eye, D_nose, D_mouth, D_map = self.dis(profile_224)
        '''
        print("D_face.shape:", D_face.shape)
        print("D_eye.shape:", D_eye.shape)
        print("D_nose.shape:", D_nose.shape)
        print("D_mouth.shape:", D_mouth.shape)
        '''
        cv2.imwrite("profile.jpg", cv2.cvtColor(tensor2im(profile_112), cv2.COLOR_BGR2RGB))
        cv2.imwrite("front.jpg", cv2.cvtColor(tensor2im(front_112), cv2.COLOR_BGR2RGB))
        feature = self.expert_net.get_feature(profile_224)
        print(feature.shape)
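# FNM.train() calls three helpers defined elsewhere in the project:
# tensor_tuple_sum(), l2_norm(), and gradient_penalty(). The sketch below is
# one plausible reading, assuming the discriminator returns a tuple of five
# per-region outputs and the penalty is the standard WGAN-GP term; the
# project's own implementations may differ in detail.
import torch


def tensor_tuple_sum(t):
    # Reduce a tuple of per-head discriminator outputs to one scalar per image.
    return sum(out.view(out.size(0), -1).mean(dim=1) for out in t)


def l2_norm(x, eps=1e-10):
    # L2-normalize a feature map over all non-batch dimensions, so that the
    # perceptual loss above reduces to one minus a cosine similarity.
    norm = x.pow(2).sum(dim=(1, 2, 3), keepdim=True).clamp_min(eps).sqrt()
    return x / norm


def gradient_penalty(out, inter, device):
    # WGAN-GP penalty: (||d out / d inter||_2 - 1)^2 on interpolated samples.
    grad = torch.autograd.grad(outputs=out, inputs=inter,
                               grad_outputs=torch.ones_like(out).to(device),
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()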
class Solver(object):
    def __init__(self, config, data_loader):
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.build_model()

    def build_model(self):
        """Build generator and discriminator."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr, [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def to_variable(self, x):
        """Convert a tensor to a variable (on GPU when available)."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Convert a variable back to a CPU tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)."""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Train generator and discriminator."""
        fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, images in enumerate(self.data_loader):
                # ===================== Train D =====================#
                images = self.to_variable(images)
                batch_size = images.size(0)
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train D to recognize real images as real.
                outputs = self.discriminator(images)
                # Least-squares loss instead of binary cross-entropy
                # (optional, but tends to stabilize training).
                real_loss = torch.mean((outputs - 1) ** 2)

                # Train D to recognize fake images as fake.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs ** 2)

                # Backprop + optimize
                d_loss = real_loss + fake_loss
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # ===================== Train G =====================#
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train G so that D recognizes G(z) as real.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                g_loss = torch.mean((outputs - 1) ** 2)

                # Backprop + optimize
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Print the log info
                if (i + 1) % self.log_step == 0:
                    print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
                          'd_fake_loss: %.4f, g_loss: %.4f'
                          % (epoch + 1, self.num_epochs, i + 1, total_step,
                             real_loss.item(), fake_loss.item(), g_loss.item()))

                # Save the sampled images
                if (i + 1) % self.sample_step == 0:
                    fake_images = self.generator(fixed_noise)
                    torchvision.utils.save_image(
                        self.denorm(fake_images.data),
                        os.path.join(self.sample_path,
                                     'fake_samples-%d-%d.png' % (epoch + 1, i + 1)))

            # Save the model parameters after each epoch
            g_path = os.path.join(self.model_path, 'generator-%d.pkl' % (epoch + 1))
            d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % (epoch + 1))
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    def sample(self):
        # Load trained parameters
        g_path = os.path.join(self.model_path, 'generator-%d.pkl' % self.num_epochs)
        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % self.num_epochs)
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample images from fresh noise
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)
        print("Saved sampled images to '%s'" % sample_path)
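# A minimal sketch of driving the Solver above, assuming an argparse-style
# `config` carrying the fields read in __init__ (z_dim, lr, image_size, ...).
# `config.image_path` and the ImageOnlyFolder wrapper are illustrative
# assumptions, not part of the original code: Solver.train() iterates over
# bare image batches normalized to (-1, 1), so ImageFolder's labels are dropped.
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


class ImageOnlyFolder(Dataset):
    # Wrap ImageFolder but return only the image tensor, no label.
    def __init__(self, root, transform):
        self.inner = datasets.ImageFolder(root, transform=transform)

    def __len__(self):
        return len(self.inner)

    def __getitem__(self, idx):
        return self.inner[idx][0]


def run_solver(config):
    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map pixels to (-1, 1)
    ])
    data_loader = DataLoader(ImageOnlyFolder(config.image_path, transform),
                             batch_size=config.batch_size, shuffle=True, num_workers=2)
    solver = Solver(config, data_loader)
    solver.train()   # periodic samples -> config.sample_path, weights -> config.model_path
    solver.sample()  # final fake_samples-final.png grid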
        for g_iter in range(1):
            # Generator step: fool D on (image, prediction) pairs while keeping
            # the focal segmentation loss against the ground-truth labels.
            optimizer_G.zero_grad()
            gen_imgs = torch.cat((img, output), 1)
            loss_G = -torch.mean(D(gen_imgs))
            loss_focal = criterion(output, label)
            loss = loss_focal + loss_G
            loss.backward()
            optimizer_G.step()
            train_loss += loss_focal.item() / trainSize

        # Validation pass with both networks frozen
        G.eval(), D.eval()
        with torch.no_grad():
            for batch in val_dataloader:
                img_v, label_v = batch[0].to(device), batch[1].to(device)
                output_v = G(img_v)
                loss = criterion(output_v, label_v)
                val_loss += loss.item() / valSize

        loss_track.append((train_loss, loss_G, loss_D, val_loss))
        torch.save(loss_track, 'checkpoint_GAN/loss.pth')
        print('[{:4d}/{}], tr_ls: {:.5f}, G_ls: {:.5f}, D_ls: {:.5f}, te_ls: {:.5f}'.format(
            epoch + 1, epoch_num, train_loss, loss_G, loss_D, val_loss))
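# The fragment above is the generator half of an adversarial segmentation
# loop; img, output, label, and loss_D come from a preceding critic step that
# is not shown. The sketch below is one plausible reconstruction, assuming a
# WGAN-style critic D scored on concatenated (image, mask) pairs with weight
# clipping, a one-channel float label mask, and a hypothetical optimizer_D and
# critic-step count n_critic; the real training script may differ.
for d_iter in range(n_critic):
    optimizer_D.zero_grad()
    output = G(img).detach()                 # predicted mask, detached for the critic step
    real_pair = torch.cat((img, label), 1)   # image + ground-truth mask
    fake_pair = torch.cat((img, output), 1)  # image + predicted mask
    loss_D = torch.mean(D(fake_pair)) - torch.mean(D(real_pair))
    loss_D.backward()
    optimizer_D.step()
    for p in D.parameters():                 # WGAN weight clipping
        p.data.clamp_(-0.01, 0.01)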