class Tester(object):
    """Evaluate a trained AutoEncoder checkpoint on the CUB dataset.

    Restores the model from ``args.checkpoint``, builds train/test/val
    loaders, and exposes:

    * ConSE zero-shot prediction (:meth:`conse_prediction`)
    * latent-space k-NN prediction (:meth:`knn_prediction`)
    * t-SNE visualisation of latent features (:meth:`tSNE`)
    * reconstruction dumps (:meth:`validation_recon`)
    * nearest-neighbour image retrieval (:meth:`test_nn_image`)

    Encoded features / labels / predictions are memoised in ``self.result``
    under the keys ``'<mode>_feature'``, ``'<mode>_label'``, ``'<mode>_pred'``.
    """

    def __init__(self, args):
        super(Tester, self).__init__()
        self.args = args
        # Restore the trained auto-encoder and freeze it for inference.
        self.model = AutoEncoder(args)
        self.model.load_state_dict(torch.load(args.checkpoint))
        self.model.cuda()
        self.model.eval()
        # Cache of computed numpy arrays, keyed per mode (see class docstring).
        self.result = {}
        self.train_dataset = CUBDataset(split='train')
        self.test_dataset = CUBDataset(split='test')
        self.val_dataset = CUBDataset(split='val')
        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=args.batch_size)
        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      batch_size=args.batch_size)
        self.val_loader = DataLoader(dataset=self.val_dataset,
                                     batch_size=100, shuffle=True)
        train_cls = self.train_dataset.get_classes('train')
        test_cls = self.test_dataset.get_classes('test')
        print("Load class")
        print(train_cls)
        print(test_cls)
        self.zsl = ZSLPrediction(train_cls, test_cls)

    # def tSNE(self):

    def conse_prediction(self, mode='test'):
        """Return ConSE zero-shot accuracy (fraction correct) on ``mode``.

        The per-batch softmax scores are mapped to class predictions via
        ``ZSLPrediction.conse_wordembedding_predict`` using the top
        ``args.conse_top_k`` classes.
        """

        def pred(recon_x, z_tilde, output):
            cls_score = output.detach().cpu().numpy()
            print(cls_score)
            return self.zsl.conse_wordembedding_predict(cls_score,
                                                        self.args.conse_top_k)

        self.get_features(mode=mode, pred_func=pred)
        if (mode + '_pred') not in self.result:
            raise NotImplementedError
        target = self.result[mode + '_label']
        pred_label = self.result[mode + '_pred']
        print(target)
        print(pred_label)
        acc = np.sum(target == pred_label)
        print(acc)
        total = target.shape[0]
        print(total)
        return acc / float(total)

    def knn_prediction(self, mode='test'):
        """Return k-NN (k=5, cosine) accuracy over the cached latent features."""
        self.get_features(mode=mode, pred_func=None)
        if (mode + '_feature') not in self.result:
            raise NotImplementedError
        features = self.result[mode + '_feature']
        labels = self.result[mode + '_label']
        print(labels)
        self.zsl.construct_nn(features, labels, k=5, metric='cosine',
                              sample_num=5)
        pred = self.zsl.nn_predict(features)
        acc = np.sum(labels == pred)
        total = labels.shape[0]
        return acc / float(total)

    def tSNE(self, mode='train'):
        """Save a t-SNE plot of 30 randomly sampled latent features to
        ``args.tsne_out``."""
        self.get_features(mode=mode, pred_func=None)
        total_num = self.result[mode + '_feature'].shape[0]
        random_index = np.random.permutation(total_num)[:30]
        self.zsl.tSNE_visualization(
            self.result[mode + '_feature'][random_index, :],
            self.result[mode + '_label'][random_index],
            mode=mode,
            file_name=self.args.tsne_out)

    def get_features(self, mode='test', pred_func=None):
        """Encode the whole ``mode`` split and cache features and labels.

        Args:
            mode: 'train' or 'test' — selects the corresponding loader.
            pred_func: optional callable ``(recon_x, z_tilde, softmax_output)
                -> per-batch predictions``; when given, predictions are also
                cached under ``'<mode>_pred'``.

        Results are memoised in ``self.result``; a repeated call with the
        same ``mode`` (and same pred/no-pred request) is a no-op.

        Raises:
            ValueError: if ``mode`` has no dedicated loader. (The original
            code fell through and crashed later with a NameError on
            ``loader``.)
        """
        self.model.eval()
        # Serve from cache when the requested artefacts already exist.
        if pred_func is None and (mode + '_feature') in self.result:
            print("Use cached result")
            return
        if pred_func is not None and (mode + '_pred') in self.result:
            print("Use cached result")
            return
        if mode == 'train':
            loader = self.train_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            raise ValueError(
                'get_features supports mode "train" or "test", got {!r}'.format(mode))
        all_z = []
        all_label = []
        all_pred = []
        for data in tqdm(loader):
            # if idx == 3:
            #     break
            images = Variable(data['image64_crop'].cuda())
            target = Variable(data['class_id'].cuda())
            recon_x, z_tilde, output = self.model(images)
            target = target.detach().cpu().numpy()
            output = F.softmax(output, dim=1)
            all_label.append(target)
            all_z.append(z_tilde.detach().cpu().numpy())
            if pred_func is not None:
                all_pred.append(pred_func(recon_x, z_tilde, output))
        self.result[mode + '_feature'] = np.vstack(all_z)  # all features
        # print(all_label)
        self.result[mode + '_label'] = np.hstack(all_label)  # all test label
        if pred_func is not None:
            self.result[mode + '_pred'] = np.hstack(all_pred)
            print(self.result[mode + '_pred'].shape)
        print(self.result[mode + '_feature'].shape)
        print(self.result[mode + '_label'].shape)

    def validation_recon(self):
        """Dump one validation batch as original / reconstruction PNG pairs.

        Bug fix: the original code saved the *input* images under
        ``./recon/recon*.png`` and the *reconstructions* under
        ``./recon/orig*.png``; filenames now match their content.
        """
        self.model.eval()
        for idx, data in enumerate(self.val_loader):
            if idx == 1:
                break
            images = Variable(data['image64_crop'].cuda())
            recon_x, z_tilde, output = self.model(images)
            all_recon_images = recon_x.detach().cpu().numpy()  # N x 3 x 64 x 64
            all_origi_images = data['image64_crop'].numpy()    # N x 3 x 64 x 64
            for i in range(all_recon_images.shape[0]):
                imsave(
                    './recon/recon' + str(i) + '.png',
                    np.transpose(np.squeeze(all_recon_images[i, :, :, :]), [1, 2, 0]))
                imsave(
                    './recon/orig' + str(i) + '.png',
                    np.transpose(np.squeeze(all_origi_images[i, :, :, :]), [1, 2, 0]))

    def test_nn_image(self):
        """For 100 random test images, save each next to its nearest training
        image in latent space (Euclidean 1-NN over cached features)."""
        self.get_features(mode='test', pred_func=None)
        self.get_features(mode='train', pred_func=None)
        N = 100
        random_index = np.random.permutation(
            self.result['test_feature'].shape[0])[:N]
        from sklearn.neighbors import NearestNeighbors
        neigh = NearestNeighbors()
        neigh.fit(self.result['train_feature'])
        test_feature = self.result['test_feature'][random_index, :]
        _, pred_index = neigh.kneighbors(test_feature, 1)
        for i in range(N):
            test_index = random_index[i]
            data = self.test_dataset[test_index]
            image = data['image64_crop'].numpy()  # 1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/test' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))
            train_index = pred_index[i][0]
            print(train_index)
            data = self.train_dataset[train_index]
            image = data['image64_crop'].numpy()  # 1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/train' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))
class Trainer(object):
    """Adversarial auto-encoder trainer (generator G + latent discriminator D).

    Per batch: D is trained to separate prior samples ``z ~ N(0, sigma^2 I)``
    from encoded ``z_tilde``; G is trained on reconstruction MSE, the
    adversarial term, and a classification loss, weighted by ``lambda_1`` /
    ``lambda_2``. Metrics and sample images are logged to TensorBoard.
    """

    def __init__(self, args):
        # ---- networks ----
        self.G = AutoEncoder(args)
        self.D = Discriminator(args)
        self.G.weight_init()
        self.D.weight_init()
        self.G.cuda()
        self.D.cuda()
        self.criterion = nn.MSELoss()

        # ---- data ----
        self.train_dataset = CUBDataset(split='train')
        self.valid_dataset = CUBDataset(split='val')
        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=args.batch_size)
        self.valid_loader = DataLoader(dataset=self.valid_dataset,
                                       batch_size=args.batch_size)

        # ---- optimizers & LR schedules ----
        self.G_optim = optim.Adam(self.G.parameters(), lr=args.lr_G)
        # D runs at half the configured LR (keeps D from overpowering G).
        self.D_optim = optim.Adam(self.D.parameters(), lr=0.5 * args.lr_D)
        self.G_scheduler = StepLR(self.G_optim, step_size=30, gamma=0.5)
        self.D_scheduler = StepLR(self.D_optim, step_size=30, gamma=0.5)

        # ---- hyper-parameters ----
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.z_var = args.z_var
        self.sigma = args.sigma
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2

        # One timestamped run directory per training session.
        log_dir = os.path.join(args.log_dir,
                               datetime.now().strftime("%m_%d_%H_%M_%S"))
        # if not os.path.isdir(log_dir):
        #     os.makedirs(log_dir)
        self.writter = SummaryWriter(log_dir)

    def train(self):
        """Run the full training loop; validates every 2nd epoch."""
        global_step = 0
        self.G.train()
        self.D.train()
        for epoch in range(self.epochs):
            # NOTE: schedulers step at the *start* of the epoch (original
            # behaviour, kept for schedule compatibility).
            self.G_scheduler.step()
            self.D_scheduler.step()
            print("training epoch {}".format(epoch))
            all_num = 0.0
            acc_num = 0.0
            images_index = 0
            for data in tqdm(self.train_loader):
                images = Variable(data['image64'].cuda())
                target_image = Variable(data['image64'].cuda())
                target = Variable(data['class_id'].cuda())

                recon_x, z_tilde, output = self.G(images)
                # Prior sample z ~ N(0, sigma^2 I), matching z_tilde's shape.
                z = Variable((self.sigma * torch.randn(z_tilde.size())).cuda())
                log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)
                # Labels sized to the actual batch (last batch may be smaller).
                ones = Variable(torch.ones(images.size()[0], 1).cuda())
                zeros = Variable(torch.zeros(images.size()[0], 1).cuda())

                # ======== Train Discriminator ======== #
                D_z = self.D(z)
                D_z_tilde = self.D(z_tilde)
                D_loss = F.binary_cross_entropy_with_logits(D_z + log_p_z, ones) + \
                    F.binary_cross_entropy_with_logits(D_z_tilde + log_p_z, zeros)
                total_D_loss = self.lambda_1 * D_loss
                self.D_optim.zero_grad()
                # retain_graph: the same forward graph is reused for the
                # generator update below.
                total_D_loss.backward(retain_graph=True)
                self.D_optim.step()

                # ======== Train Generator ======== #
                recon_loss = F.mse_loss(recon_x, target_image,
                                        reduction='sum').div(self.batch_size)
                G_loss = F.binary_cross_entropy_with_logits(D_z_tilde + log_p_z, ones)
                class_loss = F.cross_entropy(output, target)
                total_G_loss = recon_loss + self.lambda_1 * G_loss \
                    + self.lambda_2 * class_loss
                self.G_optim.zero_grad()
                total_G_loss.backward()
                self.G_optim.step()

                # ======== Compute Classification Accuracy ======== #
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)

                # ======== Log by TensorBoardX ======== #
                global_step += 1
                if (global_step + 1) % 10 == 0:
                    self.writter.add_scalar('train/recon_loss',
                                            recon_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/G_loss',
                                            G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/D_loss',
                                            D_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/classify_loss',
                                            class_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/total_G_loss',
                                            total_G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/acc',
                                            acc_num / all_num, global_step)
                # Log up to 5 random reconstruction/target image pairs per epoch.
                if images_index < 5 and torch.rand(1) < 0.5:
                    self.writter.add_image('train_output_{}'.format(images_index),
                                           recon_x[0], global_step)
                    self.writter.add_image('train_target_{}'.format(images_index),
                                           target_image[0], global_step)
                    images_index += 1
            if epoch % 2 == 0:
                self.validate(global_step)

    def validate(self, global_step):
        """Evaluate on the validation split and log accuracy / recon loss.

        Runs without gradient tracking and — bug fix — restores ``train()``
        mode on exit: the original left G and D in ``eval()`` mode, silently
        changing BatchNorm/Dropout behaviour for all subsequent training.
        """
        self.G.eval()
        self.D.eval()
        acc_num = 0.0
        all_num = 0.0
        recon_loss = 0.0
        images_index = 0
        with torch.no_grad():  # no gradients needed during evaluation
            for data in tqdm(self.valid_loader):
                images = Variable(data['image64'].cuda())
                target_image = Variable(data['image64'].cuda())
                target = Variable(data['class_id'].cuda())
                recon_x, z_tilde, output = self.G(images)
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)
                recon_loss += F.mse_loss(recon_x, target_image,
                                         reduction='sum').cpu().item()
                if images_index < 5:
                    self.writter.add_image('valid_output_{}'.format(images_index),
                                           recon_x[0], global_step)
                    self.writter.add_image('valid_target_{}'.format(images_index),
                                           target_image[0], global_step)
                    images_index += 1
        self.writter.add_scalar('valid/acc', acc_num / all_num, global_step)
        self.writter.add_scalar('valid/recon_loss', recon_loss / all_num,
                                global_step)
        # Return to training mode so the next epoch trains correctly.
        self.G.train()
        self.D.train()