return encoded_list, label_list # Train model if args.model != 'PCA': log = [] t_total = time.time() for epoch in range(args.epochs): log_epoch = train(epoch) log.append(log_epoch) print("Optimization Finished!") print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) # Test encoded_list, label_list = test() if args.encoded is not None: np.savez(args.encoded, encoded=encoded_list, label=label_list) # Save File if args.log is not None: np.savetxt(args.log, log) if args.o is not None: torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, args.o)
def train():
    """Train the sketch-to-image refinement GAN (``net_ig`` generator, ``net_id`` discriminator).

    All settings come from the ``config`` module. A pretrained ``AE`` is loaded and
    kept in eval mode (it has no optimizer here); for each batch it maps
    (sketch, rgb) to an intermediate image ``gimg_ae`` plus ``style_feats``, which
    ``net_ig`` turns into the generated image. ``net_id`` is trained with
    ``d_hinge_loss`` and ``net_ig`` with ``g_hinge_loss`` plus a reconstruction
    term (L1+MSE for 'shoe', LPIPS otherwise). An EMA copy of the generator
    parameters (``net_ig_ema``) is maintained.

    Each epoch runs ITERS_PER_EPOCH iterations; on the configured intervals the
    loop saves preview images, appends to ``gan_log.txt``, checkpoints all
    models/optimizers, and computes FID for matched and permuted style references.

    Raises:
        ValueError: if ``config.DATA_NAME`` has no matching RefineGenerator.
    """
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    ITERS_PER_EPOCH = 10000
    # Cache of real-image Inception statistics (a pickle, despite the .npy suffix).
    fid_cache_path = '%s_fid_feats.npy' % (DATA_NAME)

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../gan_log.txt'
    # Create/truncate the log file for this run.
    with open(log_file_path, 'w'):
        pass

    dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3,
                                 im_size=IM_SIZE_GAN, rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset, BATCH_SIZE_GAN, sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS, pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    # Separate unpaired image sets, only used by make_matrix for preview grids.
    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME in ('art', 'shoe'):
        from models import RefineGenerator_art as RefineGenerator
    else:
        # Previously RefineGenerator stayed None and the call below failed with an opaque TypeError.
        raise ValueError('Unsupported DATA_NAME for the refine generator: %r' % (DATA_NAME,))

    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    net_id = Discriminator(nc=3).cuda()  # we use the patch_gan, so the im_size for D should be 512 even if training image size is 1024

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    # Running averages for the periodic log line.
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()  # never updated below; reported as-is
    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    # History of [fid_matched, fid_permuted]; seeded so the log line always has a value.
    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(ITERS_PER_EPOCH)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            # random.randint is inclusive, so rd is in {0, 1, 2, 3}: sketch 3 is picked
            # for rd 2 and 3 (probability 1/2). NOTE(review): confirm this weighting is intended.
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                # Fixed preview batch, refreshed at the start of every epoch.
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            # Only monitored (not optimized): how well the frozen AE reconstructs rgb_img.
            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) + F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                # With a shuffled style reference, the output's grey-scale should still match the real image.
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            # Exponential moving average of generator weights (decay 0.999).
            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # show the current images
                with torch.no_grad():
                    # Render previews with the EMA weights, then restore the live weights.
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt, fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1), IM_SIZE_GAN),
                        gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN),
                        gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN),
                    ])

                    # NOTE(review): torchvision >= 0.10 renamed `range` to `value_range`
                    # (the old name is deprecated) — confirm against the installed version.
                    vutils.save_image(
                        gimg, f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True, range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg')

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f} D: {losses_d_img.avg:.4f} MSE: {losses_mse.avg:.4f} Rec: {losses_rec_s.avg:.5f} FID: {fid:.4f}'.format(
                    epoch, iteration, losses_g_img=losses_g_img, losses_d_img=losses_d_img,
                    losses_mse=losses_mse, losses_rec_s=losses_rec_s, fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                with open(log_file_path, 'a') as log_file:
                    log_file.write(log_msg + '\n')

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == ITERS_PER_EPOCH:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    # NOTE(review): pickle.load trusts this local cache file; never point it at untrusted data.
                    if os.path.exists(fid_cache_path):
                        with open(fid_cache_path, 'rb') as cache_file:
                            real_features = pickle.load(cache_file)
                    else:
                        feats = extract_feature_from_generator_fn(
                            real_image_loader(dataloader, n_batches=fid_batch_images), inception)
                        real_features = {
                            'feats': feats,
                            'mean': np.mean(feats, 0),
                            'cov': np.cov(feats, rowvar=False),
                        }
                        with open(fid_cache_path, 'wb') as cache_file:
                            pickle.dump(real_features, cache_file)

                # FID with each sketch's own style reference ...
                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception, total=fid_batch_images)
                cur_fid = calc_fid(sample_features, real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                # ... and with permuted (mismatched) style references.
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset, net_ae, net_ig, n_batches=fid_batch_images),
                    inception, total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm, real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                with open(log_file_path, 'a') as log_file:
                    log_file.write(log_msg + '\n')
class BiomeAE():
    """Build, train and query an ``AE`` predictor that maps X1 to X2.

    ``args.model`` selects the loss used by :meth:`loss_fn`:
    ``"BiomeAE"`` / ``"BiomeAESnip"`` (RMSE), ``"BiomeAEL0"`` (RMSE plus the
    predictor's ``regularization()``, built with ``mlp_type="L0"``) or
    ``"BiomeVAE"`` (BCE + KL divergence over 28*28 inputs).
    """

    def __init__(self, args):
        """Store configuration from *args*, seed torch and choose the device.

        Side effect: sets the ``CUDA_VISIBLE_DEVICES`` environment variable from
        ``args.gpu_num``.
        """
        if args.model in ["BiomeAEL0"]:
            self.mlp_type = "L0"
        else:
            self.mlp_type = None

        self.model_alias = args.model_alias
        self.model = args.model
        # Checkpoint written by train(); read back by train() when args.contr is set.
        self.snap_loc = os.path.join(args.vis_dir, "snap.pt")

        # NOTE(review): only effective if CUDA has not already been initialised in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
        self.predictor = None

        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_transformation(self):
        """No input transformation is used; always returns None."""
        return None

    def loss_fn(self, recon_x, x, mean, log_var):
        """Return the training loss for ``self.model``.

        Args:
            recon_x: predictor output.
            x: target tensor.
            mean, log_var: latent statistics; only used for ``"BiomeVAE"``.

        Raises:
            ValueError: if ``self.model`` has no loss defined (previously this
                silently returned None and failed later at ``backward()``).
        """
        if self.model in ("BiomeAE", "BiomeAESnip"):
            return torch.sqrt(torch.nn.MSELoss()(recon_x, x))
        if self.model == "BiomeAEL0":
            return torch.sqrt(torch.nn.MSELoss()(recon_x, x)) + self.predictor.regularization()
        if self.model == "BiomeVAE":
            bce = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (bce + kld) / x.size(0)
        raise ValueError("No loss defined for model {!r}".format(self.model))

    def param_l0(self):
        """Delegate to the predictor's ``param_l0()``."""
        return self.predictor.param_l0()

    def init_fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        """Create data loaders, the ``AE`` predictor, the Adam optimizer and an LR scheduler."""
        self.train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
        self.test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

        self.predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            mlp_type=self.mlp_type,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(self.device)
        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)

    def train(self, args):
        """Train the predictor for ``args.epochs`` epochs.

        If ``args.contr`` is set, the weights are first loaded from ``self.snap_loc``
        and not saved afterwards; otherwise the trained weights are saved there.
        The learning rate is decayed (x0.8) roughly 10 times over training.
        """
        if args.contr:
            print("Loading from ", self.snap_loc)
            self.predictor.load_state_dict(torch.load(self.snap_loc))

        step = 0
        iterations_per_epoch = len(self.train_loader.dataset) / args.batch_size
        num_iterations = int(iterations_per_epoch * args.epochs)
        # max(1, ...) avoids a modulo-by-zero when training has fewer than 10 iterations.
        decay_every = max(1, int(num_iterations / 10))
        last_batch = len(self.train_loader) - 1

        for epoch in range(args.epochs):
            for iteration, (x1, x2, y) in enumerate(self.train_loader):
                step += 1
                x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)

                if args.conditional:
                    x2_hat, z, mean, log_var = self.predictor(x1, y)
                else:
                    x2_hat, z, mean, log_var = self.predictor(x1)

                loss = self.loss_fn(x2_hat, x2, mean, log_var)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Since PyTorch 1.1 the scheduler must step after the optimizer.
                if (step + 1) % decay_every == 0:
                    self.scheduler.step()

                # enforce non-negative weights
                if args.nonneg_weight:
                    for layer in self.predictor.modules():
                        if isinstance(layer, nn.Linear):
                            layer.weight.data.clamp_(0.0)

                if iteration % args.print_every == 0 or iteration == last_batch:
                    print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                        epoch, args.epochs, iteration, last_batch, loss.item()))

        if not args.contr:
            print("Saving to ", self.snap_loc)
            torch.save(self.predictor.state_dict(), self.snap_loc)

    def fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        """Initialise the model and data (see init_fit), then train it."""
        self.init_fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
        self.train(args)

    def get_graph(self):
        """Return ``(nodes, weights)`` describing the predictor's Linear layers.

        ``nodes`` holds one list of string node ids per layer boundary: the input
        ids of every Linear layer, followed by the output ids of the last one.
        ``weights`` holds each Linear weight as a numpy array of shape
        (in_features, out_features). Returns ``([], [])`` if there is no Linear layer.
        """
        nodes = []
        weights = []
        last_linear = None
        for layer in self.predictor.modules():
            if isinstance(layer, nn.Linear):
                last_linear = layer
                nodes.append(["%s" % (x) for x in range(layer.in_features)])
                weights.append(layer.weight.detach().cpu().numpy().T)
        if last_linear is not None:
            nodes.append(["%s" % (x) for x in range(last_linear.out_features)])  # last linear layer
        return (nodes, weights)

    def _forward_arrays(self, X1_val, X2_val, y_val, args):
        """Convert array inputs to float tensors on ``self.device`` and run the predictor without autograd.

        Returns ``(x2, x2_hat, z, mean, log_var)``.
        """
        x1 = torch.FloatTensor(X1_val).to(self.device)
        x2 = torch.FloatTensor(X2_val).to(self.device)
        y = torch.FloatTensor(y_val).to(self.device)
        with torch.no_grad():
            if args.conditional:
                x2_hat, z, mean, log_var = self.predictor(x1, y)
            else:
                x2_hat, z, mean, log_var = self.predictor(x1)
        return x2, x2_hat, z, mean, log_var

    def predict(self, X1_val, X2_val, y_val, args):
        """Predict X2 from X1 in one batch, print the validation loss, and return a numpy array."""
        x2, x2_hat, z, mean, log_var = self._forward_arrays(X1_val, X2_val, y_val, args)
        with torch.no_grad():
            val_loss = self.loss_fn(x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}".format(val_loss.item()))
        return x2_hat.detach().cpu().numpy()

    def transform(self, X1_val, X2_val, y_val, args):
        """Return the latent codes ``z`` for X1 as a numpy array."""
        _, _, z, _, _ = self._forward_arrays(X1_val, X2_val, y_val, args)
        return z.detach().cpu().numpy()

    def get_influence_matrix(self):
        """Delegate to the predictor's ``get_influence_matrix()``."""
        return self.predictor.get_influence_matrix()
def data_aug(data, lr=0.001, epoch=800, batch_size=128, store_epochs=(700, 750, 800)):
    """Train an ``AE`` on *data* and run ``test`` on selected checkpoints.

    If the ``data_aug`` folder already exists, training is skipped and ``test``
    is called for every entry of *store_epochs* (this assumes a previous run left
    the checkpoints there). Otherwise an AE is trained for *epoch* epochs on an
    85/15 train/validation split. After each completed epoch count listed in
    *store_epochs*, the weights are saved to ``data_aug/ep{n}data_aug.pth`` and
    ``test(data, folder, n)`` is run.

    Args:
        data: raw data passed to ``AEDataset`` and ``test``.
        lr: Adam learning rate.
        epoch: number of training epochs.
        batch_size: mini-batch size for both loaders.
        store_epochs: completed-epoch counts at which to checkpoint and run ``test``.

    Returns:
        The value returned by the last ``test`` call, or None if none was made.
    """
    folder = 'data_aug'
    save_path = f'{folder}/data_augment.csv'
    clean_file(save_path)
    store_e = list(store_epochs)

    if not os.path.exists(folder):
        os.makedirs(folder)
    else:
        result = None
        for i in store_e:
            result = test(data, folder, i)
        return result

    train_loss_curve = []
    valid_loss_curve = []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE().to(device)

    dataset = AEDataset(data)
    train_size = int(0.85 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    # Validation order does not affect the loss, so no shuffling is needed.
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

    # loss function and optimizer (change them here if needed)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    result = None
    for e in range(epoch):
        train_loss = 0.0
        valid_loss = 0.0
        header = f'Epoch: {e+1}/{epoch}'
        print(f'\n{header}')
        print('-' * len(header))

        model.train()
        for inputs in tqdm(train_dataloader):
            inputs = inputs.float().to(device)
            outputs = model(inputs, device)
            loss = criterion(outputs, inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # criterion returns a batch mean; weight by batch size to get a per-sample epoch mean.
            train_loss += loss.item() * inputs.size(0)

        # Evaluation mode + no autograd: BatchNorm/Dropout behave deterministically and
        # validation batches no longer update running statistics.
        model.eval()
        with torch.no_grad():
            for inputs in tqdm(valid_dataloader):
                inputs = inputs.float().to(device)
                outputs = model(inputs, device)
                valid_loss += criterion(outputs, inputs).item() * inputs.size(0)

        loss_epoch = train_loss / max(len(train_dataset), 1)
        valid_loss_epoch = valid_loss / max(len(valid_dataset), 1)

        # Label checkpoints by completed epochs so the last one (e.g. 800 of 800) is reachable.
        epochs_done = e + 1
        if epochs_done in store_e:
            torch.save(model.state_dict(), f'{folder}/ep{epochs_done}data_aug.pth')
            print(f"training in epoch {epochs_done},start augment data!!")
            result = test(data, folder, epochs_done)

        print(f'Training loss: {loss_epoch:.4f}')
        print(f'Valid loss: {valid_loss_epoch:.4f}')

        # save loss every epoch (kept for optional plotting of training curves)
        train_loss_curve.append(loss_epoch)
        valid_loss_curve.append(valid_loss_epoch)

    return result