# --- Inference-time setup: data, checkpointed models, and label table --- #

dataset = ClassImageLoader(paths=sep_data, transform=transform, inf=True)

# Two loaders over the same dataset with identical configuration.
# NOTE(review): neither is shuffled — confirm that is intended for "random_loader".
_loader_kwargs = dict(batch_size=args.batch_size,
                      num_workers=args.num_workers,
                      drop_last=True)
loader = torch.utils.data.DataLoader(dataset, **_loader_kwargs)
random_loader = torch.utils.data.DataLoader(dataset, **_loader_kwargs)

# load model: restore the conditional U-Net from the 'inference' entry of the checkpoint.
transfer = Conditional_UNet(num_classes=args.num_classes)
sd = torch.load(args.cp_path)
transfer.load_state_dict(sd['inference'])
transfer.eval()

# Pretrained classifier with a softmax head appended so it emits probabilities.
classifer = nn.Sequential(torch.load(args.classifer_path), nn.Softmax(dim=1))
classifer.eval()

transfer.cuda()
classifer.cuda()

bs = args.batch_size

# One row per class: `onehot[c]` is the one-hot condition vector for class c.
labels = torch.arange(args.num_classes, dtype=torch.int64)
onehot = torch.eye(args.num_classes)[labels].cuda()

# Accumulators filled by downstream code (not visible in this chunk).
cls_li = []
vec_li = []
class WeatherTransfer(object):
    """Training driver for conditional image-to-image weather transfer.

    Trains a ``Conditional_UNet`` generator against an ``SNDisc``
    discriminator, using a frozen, pretrained ``estimator`` network that
    predicts the signal columns (``self.cols``) from an image.  Progress is
    logged to TensorBoard and checkpoints are written under
    ``args.save_dir/args.name``.
    """

    def __init__(self, args):
        """Set up logging, transforms, datasets, and models from ``args``."""
        self.args = args
        self.batch_size = args.batch_size
        self.global_step = 0

        os.makedirs(os.path.join(args.save_dir, args.name), exist_ok=True)
        comment = '_lr-{}_bs-{}_ne-{}_name-{}'.format(
            args.lr, args.batch_size, args.num_epoch, args.name)
        self.writer = SummaryWriter(comment=comment)

        # Consts: real/fake target tensors for the GAN objective.
        self.real = Variable_Float(1., self.batch_size)
        self.fake = Variable_Float(0., self.batch_size)
        self.lmda = 0.

        if args.augmentation:
            train_transform = transforms.Compose([
                transforms.RandomRotation(10),
                transforms.RandomResizedCrop(args.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.5, contrast=0.3,
                                       saturation=0.3, hue=0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize((args.input_size, ) * 2),
                transforms.RandomRotation(10),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
            ])

        test_transform = transforms.Compose([
            transforms.Resize((args.input_size, ) * 2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

        # Signal columns the estimator predicts; their count doubles as the
        # conditioning-vector size for generator and discriminator.
        self.cols = ['clouds', 'temp', 'humidity', 'pressure', 'windspeed']
        self.num_classes = len(self.cols)
        self.transform = {'train': train_transform, 'test': test_transform}
        self.train_set, self.test_set = self.load_data(
            varbose=True, image_only=args.image_only)
        self.build()

    def load_data(self, varbose=False, image_only=False, train_data_rate=0.7):
        """Build the train/test datasets.

        If ``image_only``, split raw image paths by ``train_data_rate``;
        otherwise load signal dataframes, standardize the signal columns, and
        split on the precomputed ``mode`` column.  Returns
        ``(train_set, test_set)``.
        """
        # FIX: previously this method read a module-level ``args`` global;
        # use the instance's args like every other method.
        args = self.args

        print('Start loading image files...')
        if image_only:
            paths = glob(os.path.join(args.image_root, '*'))
            print('loaded {} data'.format(len(paths)))
            pivot = int(len(paths) * train_data_rate)
            paths_sep = {'train': paths[:pivot], 'test': paths[pivot:]}
            loader = lambda s: ImageLoader(paths_sep[s],
                                           transform=self.transform[s])
        else:
            df = pd.read_pickle(args.pkl_path)
            # NOTE(review): hard-coded reference pickle supplies the
            # standardization statistics — consider making this a CLI arg.
            temp = pd.read_pickle(
                '/mnt/fs2/2019/okada/from_nitta/parm_0.3/sepalated_data_wo-outlier.pkl'
            )
            df_ = temp.loc[:, self.cols].fillna(0)
            df_mean = df_.mean()
            df_std = df_.std()
            # Standardize the signal columns with the reference statistics.
            df.loc[:, self.cols] = (df.loc[:, self.cols].fillna(0)
                                    - df_mean) / df_std
            print('loaded {} signals data'.format(len(df)))
            df_shuffle = df.sample(frac=1)
            df_sep = {
                'train': df_shuffle[df_shuffle['mode'] == 'train'],
                'test': df_shuffle[df_shuffle['mode'] == 'test']
            }
            del df, df_shuffle, temp
            loader = lambda s: FlickrDataLoader(args.image_root, df_sep[s],
                                                self.cols,
                                                transform=self.transform[s])

        train_set = loader('train')
        test_set = loader('test')
        print('train:{} test:{} sets have already loaded.'.format(
            len(train_set), len(test_set)))
        return train_set, test_set

    def build(self):
        """Instantiate models, restore the latest checkpoint if one exists,
        and create optimizers and data loaders."""
        args = self.args

        # Models
        print('Build Models...')
        self.inference = Conditional_UNet(num_classes=self.num_classes)
        self.discriminator = SNDisc(num_classes=self.num_classes)
        exist_cp = sorted(glob(os.path.join(args.save_dir, args.name, '*')))
        if len(exist_cp) != 0:
            # Resume from the lexicographically-last checkpoint file.
            print('Load checkpoint:{}'.format(exist_cp[-1]))
            sd = torch.load(exist_cp[-1])
            self.inference.load_state_dict(sd['inference'])
            self.discriminator.load_state_dict(sd['discriminator'])
            self.epoch = sd['epoch']
            self.global_step = sd['global_step']
            print('Success checkpoint loading!')
        else:
            print('Initialize training status.')
            self.epoch = 0
            self.global_step = 0

        # Frozen pretrained estimator used as a weather-signal "oracle".
        self.estimator = torch.load(args.estimator_path)
        self.estimator.eval()

        # Models to CUDA (plain loop instead of a side-effect comprehension).
        for model in (self.inference, self.discriminator, self.estimator):
            model.cuda()

        # Optimizer
        self.g_opt = torch.optim.Adam(self.inference.parameters(),
                                      lr=args.lr,
                                      betas=(0.0, 0.999),
                                      weight_decay=args.lr / 20)
        self.d_opt = torch.optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr,
                                      betas=(0.0, 0.999),
                                      weight_decay=args.lr / 20)

        # NOTE(review): should these loaders also use a sampler? (translated
        # from the original Japanese comment)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_set,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=args.num_workers)
        if args.sampler:
            self.random_loader = torch.utils.data.DataLoader(
                self.train_set,
                batch_size=args.batch_size,
                sampler=ImbalancedDatasetSampler(self.train_set),
                drop_last=True,
                num_workers=args.num_workers)
        else:
            self.random_loader = torch.utils.data.DataLoader(
                self.train_set,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=args.num_workers)

        if not args.image_only:
            self.test_loader = torch.utils.data.DataLoader(
                self.test_set,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=True,
                num_workers=args.num_workers)
            test_data_iter = iter(self.test_loader)
            # FIX: use built-in next(); DataLoader iterators no longer expose
            # a Python-2-style .next() method on current PyTorch.
            # Two fixed random batches are cached on GPU for evaluation.
            self.test_random_sample = [
                tuple(d.to('cuda') for d in next(test_data_iter))
                for _ in range(2)
            ]
            del test_data_iter, self.test_loader

        self.scalar_dict = {}
        self.image_dict = {}
        # Linear blend helper controlled by the annealed self.lmda.
        self.shift_lmda = lambda a, b: (1. - self.lmda) * a + self.lmda * b
        print('Build has been completed.')

    def update_inference(self, images, r_labels):
        """One generator step: adversarial + consistency + weather losses."""
        # --- UPDATE(Inference) --- #
        self.g_opt.zero_grad()

        # Estimator's prediction for the real images (no grad through it).
        pred_labels = self.estimator(images).detach()

        fake_out = self.inference(images, r_labels)
        fake_c_out = self.estimator(fake_out)
        fake_res = self.discriminator(fake_out, r_labels)
        fake_d_out = fake_res[0]

        # Calc Generator Loss
        g_loss_adv = gen_hinge(fake_d_out)          # Adversarial loss
        g_loss_l1 = l1_loss(fake_out, images)       # Logged only, not in g_loss
        g_loss_w = pred_loss(fake_c_out, r_labels)  # Weather prediction

        # Consistency: pixel change should scale with how far the requested
        # labels are from the image's own predicted labels.
        diff = torch.mean(torch.abs(fake_out - images), [1, 2, 3])
        lmda = torch.mean(torch.abs(pred_labels - r_labels), 1)
        loss_con = torch.mean(diff / (lmda + 1e-7))  # Reconstruction loss

        lmda_con, lmda_w = (1, 1)
        g_loss = g_loss_adv + lmda_con * loss_con + lmda_w * g_loss_w
        g_loss.backward()
        self.g_opt.step()

        self.scalar_dict.update({
            'losses/g_loss/train': g_loss.item(),
            'losses/g_loss_adv/train': g_loss_adv.item(),
            'losses/g_loss_l1/train': g_loss_l1.item(),
            'losses/g_loss_w/train': g_loss_w.item(),
            'losses/loss_con/train': loss_con.item(),
            'variables/lmda': self.lmda
        })
        self.image_dict.update({
            'io/train': torch.cat([images, fake_out], dim=3),
        })

    def update_discriminator(self, images, labels):
        """One discriminator step with the hinge loss (real vs. fake)."""
        # --- UPDATE(Discriminator) ---#
        self.d_opt.zero_grad()

        # for real: condition D on the estimator's predicted labels.
        real_c_out = self.estimator(images)
        pred_labels = real_c_out.detach()
        real_d_out_pred = self.discriminator(images, pred_labels)[0]

        # for fake: detach the generator output so only D is updated.
        fake_out = self.inference(images, labels)
        fake_d_out = self.discriminator(fake_out.detach(), labels)[0]

        d_loss = dis_hinge(fake_d_out, real_d_out_pred)
        d_loss.backward()
        self.d_opt.step()

        self.scalar_dict.update({'losses/d_loss/train': d_loss.item()})

    def evaluation(self):
        """Evaluate on the two cached test batches: for each reference image,
        transfer its labels onto the whole batch and record losses/images."""
        g_loss_l1_ = []
        g_loss_adv_ = []
        g_loss_w_ = []
        fake_out_li = []
        d_loss_ = []

        images, labels = self.test_random_sample[0]
        blank = torch.zeros_like(images[0]).unsqueeze(0)
        ref_images, ref_labels = self.test_random_sample[1]

        for i in range(self.batch_size):
            with torch.no_grad():
                if ref_labels is None:
                    ref_labels = self.estimator(ref_images)
                # Broadcast the i-th reference label to the whole batch.
                ref_labels_expand = torch.cat(
                    [ref_labels[i]] * self.batch_size).view(
                        -1, self.num_classes)
                fake_out_ = self.inference(images, ref_labels_expand)
                fake_c_out_ = self.estimator(fake_out_)
                # D input is (fake_out_, ref_labels_expand) — resolves the
                # original Japanese review comment.
                real_d_out_ = self.discriminator(images, labels)[0]
                fake_d_out_ = self.discriminator(fake_out_,
                                                 ref_labels_expand)[0]

            fake_out_li.append(fake_out_)
            g_loss_adv_.append(gen_hinge(fake_d_out_).item())
            g_loss_l1_.append(l1_loss(fake_out_, images).item())
            g_loss_w_.append(pred_loss(fake_c_out_, ref_labels_expand).item())
            d_loss_.append(dis_hinge(fake_d_out_, real_d_out_).item())

        # --- WRITING SUMMARY ---#
        self.scalar_dict.update({
            'losses/g_loss_adv/test': np.mean(g_loss_adv_),
            'losses/g_loss_l1/test': np.mean(g_loss_l1_),
            'losses/g_loss_w/test': np.mean(g_loss_w_),
            'losses/d_loss/test': np.mean(d_loss_)
        })

        # Grid: top row = blank + reference images; below = input batch
        # followed by one transferred batch per reference image.
        ref_img = torch.cat([blank] + list(torch.split(ref_images, 1)), dim=3)
        in_out_img = torch.cat([images] + fake_out_li, dim=3)
        res_img = torch.cat([ref_img, in_out_img], dim=0)
        self.image_dict.update({
            'images/test': res_img,
        })

    def update_summary(self):
        """Flush accumulated scalars and image grids to TensorBoard."""
        # Summarize
        for k, v in self.scalar_dict.items():
            spk = k.rsplit('/', 1)
            self.writer.add_scalars(spk[0], {spk[1]: v}, self.global_step)
        for k, v in self.image_dict.items():
            grid = make_grid(v, nrow=1, normalize=True, scale_each=True)
            self.writer.add_image(k, grid, self.global_step)

    def train(self):
        """Main loop: alternate D/G updates, periodically checkpoint,
        evaluate, and write summaries."""
        args = self.args

        # train setting
        eval_per_step = 1000
        display_per_step = 1000

        self.all_step = args.num_epoch * len(self.train_set) // self.batch_size
        tqdm_iter = trange(args.num_epoch, desc='Training', leave=True)
        for epoch in tqdm_iter:
            if epoch > 0:
                self.epoch += 1
            for i, (data, rand_data) in enumerate(
                    zip(self.train_loader, self.random_loader)):
                self.global_step += 1

                # Periodic checkpoint (same cadence as evaluation).
                if self.global_step % eval_per_step == 0:
                    out_path = os.path.join(
                        args.save_dir, args.name,
                        (args.name + '_e{:04d}_s{}.pt').format(
                            self.epoch, self.global_step))
                    state_dict = {
                        'inference': self.inference.state_dict(),
                        'discriminator': self.discriminator.state_dict(),
                        'epoch': self.epoch,
                        'global_step': self.global_step
                    }
                    torch.save(state_dict, out_path)

                tqdm_iter.set_description('Training [ {} step ]'.format(
                    self.global_step))

                # lmda: fixed if given on the CLI, otherwise annealed 0 -> 1.
                if args.lmda:
                    self.lmda = args.lmda
                else:
                    self.lmda = self.global_step / self.all_step

                images, d_ = (d.to('cuda') for d in data)
                rand_images, r_ = (d.to('cuda') for d in rand_data)
                # Target labels come from the estimator on a random batch.
                rand_labels = self.estimator(rand_images).detach()

                if images.size(0) != self.batch_size:
                    continue

                self.update_discriminator(images, rand_labels)

                # Train G every GD_train_ratio discriminator steps.
                if self.global_step % args.GD_train_ratio == 0:
                    self.update_inference(images, rand_labels)

                # --- EVALUATION ---#
                if (self.global_step % eval_per_step == 0) \
                        and not args.image_only:
                    self.evaluation()

                # --- UPDATE SUMMARY ---#
                if self.global_step % display_per_step == 0:
                    self.update_summary()

        print('Done: training')