def compute_update_dir(self, p_d):
    """Estimate the sign of the critic gap between true and generated data.

    Averages (-E[critic(true)] + E[critic(gen)] - 2*alpha) over 10 minibatches
    with all networks in eval mode and gradients disabled, then returns the
    negated sign of the accumulated value (+1.0/-1.0/0.0 as a numpy scalar).
    The caller presumably feeds this into self.ps to adapt alpha — TODO confirm.

    Args:
        p_d: data iterator with a .next() method yielding (data, labels).

    Returns:
        numpy float: -sign of the accumulated directed gap.
    """
    self.gen.eval()
    self.dis.eval()
    self.classifier.eval()
    i = 0
    _dir = 0
    with torch.no_grad():
        while i < 10:
            true_data, _ = p_d.next()
            true_data = tensor2Var(true_data)
            noise = create_noise(true_data.size(0), self.args.noise_size)
            noise = tensor2Var(noise)
            gen_data = self.gen(noise).detach()
            # critic scores are computed on classifier features, not raw images
            true_out = self.dis(self.classifier(true_data, 'feat'), 'critic')
            gen_out = self.dis(self.classifier(gen_data, 'feat'), 'critic')
            # margin term: `1 * 2 * alpha` kept as written (an earlier variant
            # used 0.5 * 2 * alpha, see commented line below)
            _dir += (-true_out.mean() + gen_out.mean() - 1 * 2 * self.ps.alpha).item()
            # self.ps.update_dir(-1 * ((-real_output_c.mean() + fake_output_c.mean()).item() - \
            #     0.5 * 2 * self.ps.alpha))
            i += 1
            print(true_out.mean().item(), gen_out.mean().item())
    self.gen.train()
    self.dis.train()
    self.classifier.train()
    return -1 * np.sign(_dir)
def construct_p_d(self, data_1, data_2):
    """Build a perturbed sample ("p_d_bar") from one or two data batches.

    Modes (self._type):
      'normal'      — add Gaussian noise N(0, beta_2) to data_1.
      'uniform'     — add Uniform(-beta_2, beta_2) noise to data_1.
      'inter'       — per-sample convex interpolation between data_1 and
                      data_2 with coefficient Uniform(beta_1, beta_2);
                      supports 2-D (features) and 4-D (images) batches.
      'huge_normal' — add Gaussian noise scaled by (1 + beta_2) to data_1.

    Args:
        data_1: primary batch (tensor).
        data_2: second batch, used only in 'inter' mode.

    Returns:
        Perturbed/interpolated batch on the same device as the inputs.

    Raises:
        ValueError: unknown self._type or unsupported tensor rank in 'inter'
        mode (previously these paths crashed with a NameError on an
        undefined local).
    """
    beta = self.beta_2
    beta_1 = self.beta_1
    beta_2 = self.beta_2
    if self._type == 'normal':
        noise = torch.FloatTensor(data_1.size()).normal_(0, beta)
    elif self._type == 'uniform':
        noise = torch.FloatTensor(data_1.size()).uniform_(-beta, beta)
    elif self._type == 'inter':
        # one interpolation coefficient per sample, broadcast over the rest
        if len(data_1.size()) == 2:
            _beta = torch.FloatTensor(data_1.size(0), 1).uniform_(beta_1, beta_2)
        elif len(data_1.size()) == 4:
            _beta = torch.FloatTensor(data_1.size(0), 1, 1, 1).uniform_(beta_1, beta_2)
        else:
            raise ValueError(
                "'inter' mode expects 2D or 4D batches, got %dD" % len(data_1.size()))
        _beta = tensor2Var(_beta)
        return _beta * data_1 + (1 - _beta) * data_2
    elif self._type == 'huge_normal':
        noise = torch.randn(data_1.size()) * (1 + beta)
    else:
        raise ValueError('unknown perturbation type: %r' % (self._type,))
    return data_1 + tensor2Var(noise)
def train_c(self, train_loader, semi_weight):
    """One classifier update: supervised cross-entropy on a labeled batch
    plus an entropy term on generated samples, weighted by semi_weight.

    Returns:
        (labeled CE loss, entropy on generated data, total loss) as floats.
    """
    args = self.args
    # labeled minibatch
    lab_data, lab_labels = train_loader.next()
    lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(lab_labels)
    # matching batch of generated samples (generator is frozen via detach)
    z = tensor2Var(create_noise(lab_data.size(0), args.noise_size))
    fake_batch = self.gen(z).detach()

    lab_logits = self.classifier(lab_data, 'class')
    fake_logits = self.classifier(fake_batch, 'class')

    lab_loss = F.cross_entropy(lab_logits, lab_labels)
    probs = F.softmax(fake_logits, dim=1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(1).mean()
    total_loss = lab_loss + semi_weight * entropy

    self.classifier_opt.zero_grad()
    total_loss.backward()
    self.classifier_opt.step()
    return lab_loss.cpu().item(), entropy.cpu().item(), total_loss.cpu().item()
def eval(self, data_loader):
    """Evaluate the classifier on data_loader, also tracking confidence stats.

    Returns:
        (mean CE loss, #misclassified, #samples,
         fraction of real samples with max logit > 0,
         fraction of generated samples with max logit < 0).
    """
    self.gen.eval()
    self.dis.eval()
    self.classifier.eval()
    loss_sum, wrong, n_batches = 0, 0, 0
    n_seen = 0
    real_conf, fake_rej = 0, 0
    with torch.no_grad():
        for data, labels in data_loader.get_iter():
            data, labels = tensor2Var(data), tensor2Var(labels)
            z = tensor2Var(create_noise(data.size(0), self.args.noise_size))
            fake = self.gen(z).detach()
            fake_logits = self.classifier(fake, 'class')
            logits = self.classifier(data, 'class')
            labels = labels.view(-1)
            batch = data.size(0)
            loss_sum += F.cross_entropy(logits, labels).item() * batch
            n_batches += 1
            n_seen += batch
            wrong += logits.max(1)[1].ne(labels).float().sum().item()
            # max logit > 0 on real data => confidently "real"
            real_conf += logits.max(1)[0].detach().gt(0.0).float().sum().item()
            # max logit < 0 on generated data => correctly rejected
            fake_rej += fake_logits.max(1)[0].detach().lt(0.0).float().sum().item()
    return loss_sum / n_seen, wrong, n_seen, real_conf / n_seen, fake_rej / n_seen
def calc_gradient_penalty(net, real_data, fake_data):
    """WGAN-GP gradient penalty on random interpolates of real and fake data.

    Fixes/improvements over the previous version:
      * alpha now matches the rank of the input (one coefficient per sample,
        broadcast over remaining dims). The old hard-coded (N, 1, 1, 1) shape
        silently broadcast 2-D feature batches up to 4-D.
      * per-sample gradient norm computed by flattening, replacing the
        iterated-norm loop (mathematically equivalent for p=2, clearer).
      * grad_outputs built with torch.ones_like so it lives on the right
        device without external helpers.

    Args:
        net: callable invoked as net(x, 'critic') returning per-sample scores.
        real_data, fake_data: batches of identical shape.

    Returns:
        Scalar tensor: mean((||grad||_2 - 1)^2).
    """
    batch_size = real_data.size(0)
    alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    disc_interpolates = net(interpolates, 'critic')
    gradients = grad(outputs=disc_interpolates, inputs=interpolates,
                     grad_outputs=torch.ones_like(disc_interpolates),
                     create_graph=True, retain_graph=True, only_inputs=True)[0]
    grad_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()
def get_data(self):
    """Fetch one minibatch from each of the two data streams.

    Returns:
        (batch from p_d_1, batch from p_d_2), labels discarded.
    """
    batch_a, _ = self.p_d_1.next()
    batch_b, _ = self.p_d_2.next()
    return tensor2Var(batch_a), tensor2Var(batch_b)
def train_c(self, labeled_loader, unlabeled_loader):
    """One semi-supervised GAN classifier step (Salimans-style).

    Combines supervised cross-entropy on a labeled batch, an unsupervised
    real-vs-fake loss built from log-sum-exp of the class logits, an entropy
    penalty on unlabeled predictions, and an optional consistency penalty
    between two stochastic forward passes.

    Returns:
        (labeled CE loss, unsupervised GAN loss, entropy) as floats.
    """
    args = self.args
    set_require_grad(self.classifier, requires_grad=True)
    # standard classification loss
    lab_data, lab_labels = labeled_loader.next()
    lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(lab_labels)
    lab_labels = lab_labels.view(-1)
    unl_data, _ = unlabeled_loader.next()
    unl_data = tensor2Var(unl_data)
    noise = create_noise(unl_data.size(0), args.noise_size)
    noise = tensor2Var(noise)
    gen_data = self.gen(noise).detach()
    lab_logits = self.classifier(lab_data, 'class')
    unl_logits = self.classifier(unl_data, 'class')
    gen_logits = self.classifier(gen_data, 'class')
    lab_loss = F.cross_entropy(lab_logits, lab_labels)
    # logsumexp of class logits plays the role of the "real" score
    unl_logsumexp = log_sum_exp(unl_logits)
    gen_logsumexp = log_sum_exp(gen_logits)
    # diagnostics: sigmoid(logsumexp) > 0.5 means "judged real" (unused here)
    unl_acc = torch.mean(torch.sigmoid(unl_logsumexp.detach()).gt(0.5).float())
    gen_acc = torch.mean(torch.sigmoid(gen_logsumexp.detach()).lt(0.5).float())
    # This is the typical GAN cost, where sumexp(logits) is seen as the input to the sigmoid
    true_loss = - 0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(F.softplus(unl_logsumexp))
    fake_loss = 0.5 * torch.mean(F.softplus(gen_logsumexp))
    unl_prob = F.softmax(unl_logits, dim=1)
    entropy = -(unl_prob * torch.log(unl_prob + 1e-8)).sum(1).mean()
    unl_loss = true_loss + fake_loss
    c_loss = lab_loss + args.lambda_gan * unl_loss + args.lambda_e * entropy
    if args.lambda_consistency > 0:
        # second forward pass; differs via stochastic layers (dropout etc.)
        unl_logits_2 = self.classifier(unl_data, 'class')
        unl_prob_2 = F.softmax(unl_logits_2, dim=1)
        consistency_loss = ((unl_prob - unl_prob_2) ** 2).mean()
        c_loss += args.lambda_consistency * consistency_loss
        if self.total_iter % 1000 == 0:
            print(consistency_loss)
    self.classifier_opt.zero_grad()
    c_loss.backward()
    self.classifier_opt.step()
    return lab_loss.cpu().item(), unl_loss.cpu().item(), entropy.cpu().item()
def reparameterize(self, mean_input, logvar_input):
    """Sample z ~ N(mean, sigma^2) via the reparameterization trick.

    During training draws a stochastic sample; at eval time returns the
    mean deterministically.
    """
    if not self.training:
        return mean_input
    sigma = torch.exp(0.5 * logvar_input)
    eps = tensor2Var(torch.randn(sigma.size()))
    return mean_input + sigma * eps
def visualize(self, train_loader):
    """Save generated images (from features of one training batch) and the
    corresponding original images to the log folder."""
    self.gen.eval()
    self.dis.eval()
    self.vae.eval()
    for data, _ in train_loader.get_iter():
        data = tensor2Var(data)
        with torch.no_grad():
            # image -> feature -> generator -> decoded image
            gen_images = self.vae.decode(self.gen(self.vae.get_features(data)))
        break
    gen_path = os.path.join(self.args.log_folder,
                            '%d_gen_images.png' % self.total_iter)
    # all supported datasets un-normalize from [-1, 1] to [0, 1]
    if self.args.dataset in ('mnist', 'svhn', 'cifar'):
        gen_images = gen_images * 0.5 + 0.5
    else:
        raise NotImplementedError
    vutils.save_image(gen_images.data.cpu(), gen_path, nrow=10)
    ori_path = os.path.join(self.args.log_folder,
                            '%d_ori_images.png' % self.total_iter)
    vutils.save_image(data.data.cpu() * 0.5 + 0.5, ori_path, nrow=10)
    # NOTE(review): only the VAE is switched back to train mode here,
    # matching the original behavior — confirm gen/dis are meant to stay eval.
    self.vae.train()
def train_g(self):
    """One generator step: critic loss on generated features plus a
    pull-away regularizer that spreads the generated features apart.

    Returns:
        The critic part of the loss as a float.
    """
    args = self.args
    set_require_grad(self.dis, False)
    z = tensor2Var(create_noise(args.train_batch_size, args.noise_size))
    # generate, then map through the classifier's feature extractor
    fake_feat = self.classifier(self.gen(z), 'feat')
    pt_term = pullaway_loss(fake_feat)
    critic_score = self.dis(fake_feat, 'critic')
    adv_loss = -critic_score.mean()
    total = adv_loss + args.lambda_p * pt_term
    self.gen_opt.zero_grad()
    total.backward()
    self.gen_opt.step()
    return adv_loss.cpu().item()
def train_g(self, p_d_2):
    """Generator step in feature space: adversarial critic loss plus a
    feature-matching penalty tying outputs to the input features.

    Returns:
        (critic+feature loss, feature-matching loss) as floats.
    """
    args = self.args
    set_require_grad(self.dis, False)
    real_batch, _ = p_d_2.next()
    real_feat = self.vae.get_features(tensor2Var(real_batch))
    fake_feat = self.gen(real_feat)
    critic_score = self.dis(fake_feat, 'critic')
    feat_loss = ((fake_feat - real_feat).view(fake_feat.shape[0], -1) ** 2).mean()
    g_loss = -critic_score.mean() + args.lambda_feat * feat_loss
    self.gen_opt.zero_grad()
    g_loss.backward()
    self.gen_opt.step()
    return g_loss.cpu().item(), feat_loss.cpu().item()
def visualize_generation(self, _iter):
    """Decode random latent vectors through the VAE decoder and save the
    resulting image grid as generation_<iter>.png."""
    self.vae.eval()
    latent = torch.randn(self.sample_size, self.feature_size, 1, 1)
    with torch.no_grad():
        decoded = self.vae.decode(tensor2Var(latent))
    out_path = os.path.join(self.args.log_folder, 'generation_%d.png' % _iter)
    # un-normalize from [-1, 1] to [0, 1] for saving
    tv.utils.save_image(decoded.data * 0.5 + 0.5, out_path)
def eval(self, data_loader):
    """Compute mean cross-entropy loss and misclassification count on
    data_loader with all networks in eval mode.

    Returns:
        (mean loss, #misclassified, #samples).
    """
    self.gen.eval()
    self.dis.eval()
    self.classifier.eval()
    loss_sum, n_wrong, n_batches = 0, 0, 0
    n_seen = 0
    with torch.no_grad():
        for data, labels in data_loader.get_iter():
            data, labels = tensor2Var(data), tensor2Var(labels)
            logits = self.classifier(data, 'class')
            loss_sum += F.cross_entropy(logits, labels).item() * data.size(0)
            n_batches += 1
            n_seen += data.size(0)
            n_wrong += logits.max(1)[1].ne(labels).float().sum().item()
    return loss_sum / n_seen, n_wrong, n_seen
def pullaway_loss(x1):
    """Pull-away term: mean squared cosine similarity over all distinct
    pairs of rows of x1, encouraging the rows to spread apart.

    Improvement: the diagonal mask is built directly on x1's device/dtype
    with torch.eye, instead of assembling it on CPU and converting via a
    helper.

    Args:
        x1: (N, D) batch of feature vectors, N >= 2.

    Returns:
        Scalar tensor in [0, 1].
    """
    n = x1.size(0)
    normed = F.normalize(x1)
    cos = torch.matmul(normed, normed.t())
    mask = 1.0 - torch.eye(n, device=x1.device, dtype=x1.dtype)
    return ((cos * mask) ** 2).sum() / (n * (n - 1))
def param_init_cnn(self, labeled_loader):
    """Data-dependent weight initialization for the generator and classifier
    using roughly 500 labeled images: toggles each module's init_mode flag,
    runs one forward pass, then toggles it back."""
    def set_init_mode(flag):
        def apply_fn(m):
            if hasattr(m, 'init_mode'):
                setattr(m, 'init_mode', flag)
        return apply_fn

    batches = []
    for _ in range(ceil(500 / self.args.train_batch_size)):
        lab_images, _ = labeled_loader.next()
        batches.append(lab_images)
    images = torch.cat(batches, 0)

    self.gen.apply(set_init_mode(True))
    z = tensor2Var(torch.Tensor(images.size(0), self.args.noise_size).uniform_())
    gen_images = self.gen(z)
    self.gen.apply(set_init_mode(False))

    self.classifier.apply(set_init_mode(True))
    logits = self.classifier(tensor2Var(images))
    self.classifier.apply(set_init_mode(False))
def pullaway_loss_lp(x1, p=2):
    """Pull-away term based on pairwise Lp distances, normalized by the
    maximum distance, averaged over distinct row pairs.

    Improvement: diagonal mask built on x1's device/dtype with torch.eye
    instead of a CPU tensor converted through a helper.

    Args:
        x1: (N, D) batch, N >= 2. NOTE(review): if all rows are identical,
            dist.max() is 0 and the result is NaN — same as the original.
        p: norm order for the pairwise distance.

    Returns:
        Scalar tensor.
    """
    n = x1.size(0)
    dist = torch.norm(x1[:, None] - x1, dim=2, p=p)
    dist = dist / dist.max()
    mask = 1.0 - torch.eye(n, device=x1.device, dtype=x1.dtype)
    return (dist * mask).sum() / (n * (n - 1))
def train_d(self, p_d, p_d_bar):
    """Critic (discriminator) update in VAE feature space.

    For args.iter_c inner steps: scores a perturbed "p_d_bar" feature batch
    against a mixture of true features and generator outputs, where the
    mixing ratio is controlled by self.ps.alpha, and applies a WGAN-style
    loss plus gradient penalty.

    Returns:
        Negated critic loss of the LAST inner step as a float
        (an estimate of the current critic distance).
    """
    args = self.args
    set_require_grad(self.dis, requires_grad=True)
    j = 0
    # train discriminator multiples times per generator iteration
    while j < args.iter_c:
        j += 1
        # perturbed real features (the "bar" distribution)
        true_data_bar = p_d_bar.sample_feat(self.vae.get_features)
        true_data, _ = p_d.next()
        true_data = tensor2Var(true_data)
        true_data = self.vae.get_features(true_data)
        # generator maps true features to fake features
        gen_data = self.gen(true_data).detach()
        # alpha fraction of the batch is true data, the rest generated
        true_data_size = int(true_data.size(0) * self.ps.alpha)
        gen_size = true_data.size(0) - true_data_size
        # concatenate true and gen data
        true_gen_data = torch.cat(
            [true_data[:true_data_size], gen_data[:gen_size]], dim=0)
        true_data_bar_out = self.dis(true_data_bar, 'critic')
        true_gen_data_out = self.dis(true_gen_data, 'critic')
        dis_loss = -true_data_bar_out.mean() + true_gen_data_out.mean()
        d_loss = dis_loss + \
            args.lambda_g * calc_gradient_penalty(self.dis, true_data_bar, true_gen_data)
        self.dis_opt.zero_grad()
        d_loss.backward()
        self.dis_opt.step()
    return -dis_loss.cpu().item()
def param_init_dnn(self, unlabeled_loader):
    """Data-dependent weight initialization for the classifier using roughly
    500 unlabeled images (toggle init_mode, forward once, toggle back)."""
    def set_init_mode(flag):
        def apply_fn(m):
            if hasattr(m, 'init_mode'):
                setattr(m, 'init_mode', flag)
        return apply_fn

    batches = []
    for _ in range(ceil(500 / self.args.train_batch_size)):
        unl_images, _ = unlabeled_loader.next()
        batches.append(unl_images)
    images = torch.cat(batches, 0)

    self.classifier.apply(set_init_mode(True))
    logits = self.classifier(tensor2Var(images))
    self.classifier.apply(set_init_mode(False))
def visualize(self):
    """Sample 100 images from the generator and save them as a 10x10 grid."""
    self.gen.eval()
    self.dis.eval()
    n_samples = 100
    z = create_noise(n_samples, self.args.noise_size)
    with torch.no_grad():
        z = tensor2Var(z)
        images = self.gen(z)
    out_path = os.path.join(self.args.log_folder,
                            'gen_images_%d.png' % self.total_iter)
    dataset = self.args.dataset
    if dataset == 'mnist':
        # generator emits flat vectors for mnist; reshape to image tensors
        images = images.view(-1, self.args.n_channels,
                             self.args.image_size, self.args.image_size)
    elif dataset in ('svhn', 'cifar'):
        # un-normalize from [-1, 1] to [0, 1]
        images = images * 0.5 + 0.5
    else:
        raise NotImplementedError
    vutils.save_image(images.data.cpu(), out_path, nrow=10)
def eval_gen(self, gen_num):
    """Average prediction entropy of the classifier over gen_num generated
    samples (lower entropy = more class-like generations).

    Returns:
        Mean per-sample entropy as a float.
    """
    self.gen.eval()
    self.dis.eval()
    self.classifier.eval()
    entropy_sum = 0.0
    n_done = 0
    batch = self.args.train_batch_size
    with torch.no_grad():
        while n_done < gen_num:
            # shrink the final batch so exactly gen_num samples are used
            batch = min(batch, gen_num - n_done)
            z = tensor2Var(create_noise(batch, self.args.noise_size))
            logits = self.classifier(self.gen(z), 'class')
            probs = F.softmax(logits, dim=1)
            entropy_sum += -(probs * torch.log(probs + 1e-8)).sum().item()
            n_done += batch
    return entropy_sum / n_done
def visualize_reconstruction(self, train_loader, _iter):
    """Save the first training batch and its VAE reconstruction as image
    grids (<iter>_origin.png and <iter>_reconstruct.png)."""
    self.vae.eval()
    with torch.no_grad():
        data, _ = next(iter(train_loader.get_iter()))
        data_v = tensor2Var(data)
        reconstruct, _, _ = self.vae(data_v)
    tv.utils.save_image(
        data[:self.sample_size] * 0.5 + 0.5,
        os.path.join(self.args.log_folder, '%d_origin.png' % _iter))
    tv.utils.save_image(
        reconstruct.data[:self.sample_size] * 0.5 + 0.5,
        os.path.join(self.args.log_folder, '%d_reconstruct.png' % _iter))
def finetune(self, tr_data_dict):
    """Fine-tune the VAE decoder against a frozen feature extractor.

    Per batch: reconstruct from detached features, and push the decoded
    generator outputs AWAY from the input (hinge at args.threshold), so the
    decoder separates true features from generated ones.

    NOTE(review): original indentation was lost; save/log/visualize and the
    scheduler step are reconstructed at epoch level (consistent with the
    MultiStepLR milestones [300, 400, 500] being epoch counts) — confirm.
    """
    args = self.args
    train_loader = tr_data_dict['train_loader']
    p_d = tr_data_dict['p_d']
    ######################################################################
    ### start training
    total_iter = 0
    best_loss = 1e8
    best_err = 1e8
    best_err_per = 1e8
    begin_time = time()
    stop = 0
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        self.vae_opt, [300, 400, 500], 0.1)
    for epoch in range(args.max_epochs):
        epoch_ratio = float(epoch) / float(args.max_epochs)
        self.vae.train()
        for i, (lab_data, _) in enumerate(train_loader.get_iter()):
            lab_data = tensor2Var(lab_data)
            # features are detached: only the decoder is being trained here
            feat = self.vae.get_features(lab_data).detach()
            reconstruct = self.vae.decode(feat)
            gen_feat = self.gen(feat).detach()
            reconstruct_gen = self.vae.decode(gen_feat)
            # hinge: penalize only when decoded generator output is CLOSER
            # to the input than args.threshold
            gen_reconstruction_loss = torch.max(
                tensor2Var(torch.zeros(gen_feat.shape[0])),
                args.threshold - ((reconstruct_gen - lab_data)**2).view(
                    gen_feat.shape[0], -1).mean(1)).mean()
            reconstruction_loss = ((reconstruct - lab_data)**2).mean()
            loss = reconstruction_loss + args.lambda_out * gen_reconstruction_loss
            self.vae_opt.zero_grad()
            loss.backward()
            self.vae_opt.step()
        scheduler.step()
        if args.save_vae:
            save_dict = {
                'total_iter': total_iter,
                'vae_state_dict': self.vae.state_dict(),
                'vae_opt': self.vae_opt.state_dict()
            }
            torch.save(save_dict, self.vae_checkpoint)
        self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                         (epoch, total_iter, time() - begin_time))
        self.logger.info(
            '[train] loss: %.4f, reconst loss: %.4f, gen_reconst_loss: %.4f'
            % (loss.cpu().item(), reconstruction_loss.cpu().item(),
               gen_reconstruction_loss.cpu().item()))
        self.logger.info('--------')
        if epoch % 10 == 0:
            self.visualize_reconstruction(train_loader, epoch)
            self.visualize_generation(epoch)
        begin_time = time()
        total_iter += 1
        self.total_iter += 1
def train_classifier(self, tr_data_dict):
    """Train a binary real-vs-generated classifier on frozen VAE features.

    Generator, discriminator, and VAE are frozen (no grad, eval mode); the
    classifier is trained with binary cross-entropy: label 1 for features of
    real data, 0 for features of generated data.

    NOTE(review): original indentation was lost; checkpointing/logging are
    reconstructed at epoch level — confirm. F.binary_cross_entropy expects
    probabilities in [0, 1], so the classifier output is presumably already
    sigmoid-activated — verify against the model definition.
    """
    args = self.args
    set_require_grad(self.dis, requires_grad=False)
    set_require_grad(self.gen, requires_grad=False)
    set_require_grad(self.vae, requires_grad=False)
    self.gen.eval()
    self.dis.eval()
    self.vae.eval()
    train_loader = tr_data_dict['train_loader']
    p_d_2 = tr_data_dict['p_d_2']
    total_iter = 0
    best_loss = 1e8
    best_err = 1e8
    best_err_per = 1e8
    begin_time = time()
    stop = 0
    self.visualize_embedding(p_d_2)
    for epoch in range(args.max_epochs):
        epoch_ratio = float(epoch) / float(args.max_epochs)
        self.classifier.train()
        for i, (lab_data, lab_labels) in enumerate(train_loader.get_iter()):
            lab_data, lab_labels = tensor2Var(lab_data), tensor2Var(
                lab_labels)
            noise = create_noise(lab_data.size(0), args.noise_size)
            noise = tensor2Var(noise)
            gen_data = self.gen(noise).detach()
            # both real and generated inputs are mapped to VAE features
            lab_data = self.vae.get_features(lab_data)
            gen_data = self.vae.get_features(gen_data)
            lab_logits = self.classifier(lab_data)
            gen_logits = self.classifier(gen_data)
            label_true = tensor2Var(torch.ones(lab_data.shape[0]))
            label_gen = tensor2Var(torch.zeros(gen_data.shape[0]))
            pred = torch.cat([lab_logits, gen_logits], dim=0)
            label = torch.cat([label_true, label_gen], dim=0)
            lab_loss = F.binary_cross_entropy(pred, label)
            self.classifier_opt.zero_grad()
            lab_loss.backward()
            self.classifier_opt.step()
        if args.save_classifier:
            save_dict = {
                'total_iter': total_iter,
                'classifier_state_dict': self.classifier.state_dict(),
                'classifier_opt': self.classifier_opt.state_dict()
            }
            torch.save(save_dict, self.classifier_checkpoint)
        self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                         (epoch, total_iter, time() - begin_time))
        self.logger.info('[train] loss: %.4f' % (lab_loss.cpu().item()))
        self.logger.info('--------')
        begin_time = time()
        total_iter += 1
        self.total_iter += 1
def visualize_embedding(self, p_d_2):
    """Save a 2-D t-SNE scatter plot comparing true VAE features with the
    generator's outputs on those features (~200 samples).
    """
    self.gen.eval()
    self.dis.eval()
    self.vae.eval()
    vis_size = 200
    true_emb = []
    gen_emb = []
    cum_size = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(p_d_2.get_iter()):
            data = tensor2Var(data)
            feat = self.vae.get_features(data)
            true_emb.append(feat.squeeze().cpu().numpy())
            gen_feat = self.gen(feat)
            gen_emb.append(gen_feat.squeeze().cpu().numpy())
            cum_size += data.shape[0]
            if cum_size >= vis_size:
                break
    true_emb = np.vstack(true_emb)
    gen_emb = np.vstack(gen_emb)
    # embed both sets jointly so the two clouds share one t-SNE space
    tsne = sklearn.manifold.TSNE(2)
    all_emb = np.vstack([true_emb, gen_emb])
    all_emb = tsne.fit_transform(all_emb)
    size = all_emb.shape[0] // 2
    true_emb = all_emb[:size]
    gen_emb = all_emb[size:]
    plt.clf()
    t = plt.scatter(true_emb[:, 0], true_emb[:, 1], label='true data')
    g = plt.scatter(gen_emb[:, 0], gen_emb[:, 1], label='gen data')
    plt.legend([t, g], ['true data', 'gen data'])
    save_path = os.path.join(self.args.log_folder,
                             'embedding_%d.png' % self.total_iter)
    plt.savefig(save_path)
    # NOTE(review): only the VAE is restored to train mode — confirm gen/dis
    # are intentionally left in eval mode by the caller.
    self.vae.train()
def train_ae(self, tr_data_dict):
    """Train the autoencoder with MSE reconstruction plus an L2 penalty on
    the latent mean (weighted by args.LAMBDA).

    NOTE(review): original indentation was lost; checkpointing/logging are
    reconstructed at epoch level (total_iter then counts epochs) — confirm.
    """
    args = self.args
    train_loader = tr_data_dict['train_loader']
    ######################################################################
    ### start training
    total_iter = 0
    best_loss = 1e8
    best_err = 1e8
    best_err_per = 1e8
    begin_time = time()
    stop = 0
    for epoch in range(args.max_epochs):
        epoch_ratio = float(epoch) / float(args.max_epochs)
        self.vae.train()
        for i, (lab_data, _) in enumerate(train_loader.get_iter()):
            lab_data = tensor2Var(lab_data)
            reconstruct, mean, _ = self.vae(lab_data)
            reconstruction_loss = ((reconstruct - lab_data)**2).mean()
            # keep latent means close to zero (regularizer in place of a KL term)
            feature_loss = ((mean - 0)**2).mean()
            loss = reconstruction_loss + self.args.LAMBDA * feature_loss
            self.vae_opt.zero_grad()
            loss.backward()
            self.vae_opt.step()
        if args.save_vae:
            save_dict = {
                'total_iter': total_iter,
                'vae_state_dict': self.vae.state_dict(),
                'vae_opt': self.vae_opt.state_dict()
            }
            torch.save(save_dict, self.vae_checkpoint)
        self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                         (epoch, total_iter, time() - begin_time))
        self.logger.info(
            '[train] loss: %.4f, reconst loss: %.4f, feature_loss: %.4f'
            % (loss.cpu().item(), reconstruction_loss.cpu().item(),
               feature_loss.cpu().item()))
        self.logger.info('--------')
        if epoch % 10 == 0:
            self.visualize_reconstruction(train_loader, epoch)
        begin_time = time()
        total_iter += 1
        self.total_iter += 1
def train(self, tr_data_dict):
    """Train the VAE with bounded-perturbation reconstruction objectives.

    Per batch: clamp(mean + small noise) must reconstruct the input
    (reconstruction_loss), clamp(mean + larger noise) must reconstruct it
    POORLY (hinge with margin 0.1), and latent means are kept near zero.

    NOTE(review): original indentation was lost; checkpointing/logging are
    reconstructed at epoch level — confirm.
    """
    args = self.args
    train_loader = tr_data_dict['train_loader']
    p_d = tr_data_dict['p_d']
    p_d_bar = tr_data_dict['p_d_bar']
    p_d_2 = tr_data_dict['p_d_2']
    ######################################################################
    ### start training
    total_iter = 0
    best_loss = 1e8
    best_err = 1e8
    best_err_per = 1e8
    begin_time = time()
    stop = 0
    for epoch in range(args.max_epochs):
        epoch_ratio = float(epoch) / float(args.max_epochs)
        self.dis.train()
        self.gen.train()
        self.vae.train()
        self.classifier.train()
        for i, (lab_data, _) in enumerate(train_loader.get_iter()):
            lab_data = tensor2Var(lab_data)
            # (an earlier variant also trained a real-vs-gen classifier here)
            reconstruct, mean, logvar = self.vae(lab_data)
            # small perturbation: should still reconstruct the input
            noise_in = tensor2Var(
                torch.FloatTensor(mean.size()).uniform_(
                    -self.args.beta_1, self.args.beta_1))
            # large perturbation: should NOT reconstruct the input
            noise_out = tensor2Var(
                torch.FloatTensor(mean.size()).uniform_(
                    -self.args.beta_2, self.args.beta_2))
            reconstruct = self.vae.decode(
                torch.clamp(mean + noise_in, -1.0, 1.0))
            reconstruct_out = self.vae.decode(
                torch.clamp(mean + noise_out, -1.0, 1.0))
            reconstruction_loss = ((reconstruct - lab_data)**2).mean()
            # hinge with margin 0.1: penalize only when the large-noise
            # reconstruction is too close to the input
            gen_reconstruction_loss = torch.max(
                tensor2Var(torch.zeros(lab_data.shape[0])),
                0.1 - ((reconstruct_out - lab_data)**2).view(
                    lab_data.shape[0], -1).mean(1)).mean()
            # L2 penalty on latent means (in place of a KL term)
            feature_loss = ((mean - 0)**2).mean()
            loss = reconstruction_loss + self.args.LAMBDA * feature_loss + gen_reconstruction_loss
            self.vae_opt.zero_grad()
            loss.backward()
            self.vae_opt.step()
            ##################
            ## train the model
            # (GAN updates train_d/train_g are disabled in this variant)
        if args.save_vae:
            save_dict = {
                'total_iter': total_iter,
                'vae_state_dict': self.vae.state_dict(),
                'vae_opt': self.vae_opt.state_dict()
            }
            torch.save(save_dict, self.vae_checkpoint)
        self.logger.info('epoch: %d, iter: %d, spent: %.3f s' %
                         (epoch, total_iter, time() - begin_time))
        self.logger.info(
            '[train] loss: %.4f, reconst loss: %.4f, feature_loss: %.4f, \
gen_reconstruction_loss: %.4f' % (loss.cpu().item(),
            reconstruction_loss.cpu().item(), feature_loss.cpu().item(),
            gen_reconstruction_loss.cpu().item()))
        self.logger.info('--------')
        if epoch % 10 == 0:
            self.visualize_reconstruction(train_loader, epoch)
        begin_time = time()
        total_iter += 1
        self.total_iter += 1
def test(args, vae_checkpoint, classifier_checkpoint, test_loader):
    """Novelty-detection evaluation of a trained VAE.

    Loads the VAE from vae_checkpoint, scores each test sample by negative
    reconstruction error, reports ROC AUC (label 1 = novelty), saves sample
    origin/reconstruction grids and a t-SNE plot of the feature embeddings.

    Fixes: the t-SNE object was commented out but still used below,
    crashing with a NameError; unknown datasets now raise ValueError
    instead of a NameError on an undefined model; typo in the CUDA print.

    Args:
        args: experiment configuration namespace.
        vae_checkpoint: path to the saved VAE state dict.
        classifier_checkpoint: path to a classifier checkpoint (currently
            unused; classifier loading is disabled).
        test_loader: loader with get_iter(shuffle=False) over (data, label).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.use_gpu
    feature_size = args.feature_size
    # build dataset-specific models (classifier is constructed but not loaded)
    if args.dataset == 'mnist':
        import mnist_model
        classifier = mnist_model.Classifier(args, feature_size)
        if args.feature_extractor == 'ae':
            vae = mnist_model.AE(args, feature_size)
        else:
            vae = mnist_model.VAE(args, feature_size)
    elif args.dataset == 'cifar':
        import cnn_model
        classifier = cnn_model.Classifier(args, feature_size)
        if args.feature_extractor == 'ae':
            vae = cnn_model.AE(args, feature_size)
        else:
            vae = cnn_model.VAE(args, feature_size)
    else:
        raise ValueError('unsupported dataset: %r' % (args.dataset,))
    checkpoint = torch.load(vae_checkpoint,
                            map_location=lambda storage, loc: storage)
    vae.load_state_dict(checkpoint['vae_state_dict'])
    if torch.cuda.is_available():
        print('CUDA enabled.')
        classifier.cuda()
        vae.cuda()
    # freeze everything; pure inference below
    for p in classifier.parameters():
        p.requires_grad = False
    for p in vae.parameters():
        p.requires_grad = False
    classifier.eval()
    vae.eval()
    x = []          # per-sample scores (negative reconstruction error)
    y = []          # per-sample labels
    emb = []        # VAE feature embeddings
    data_list = []
    r_list = []
    label_list = []
    for i, (data, label) in enumerate(test_loader.get_iter(shuffle=False)):
        data = tensor2Var(data)
        feat = vae.get_features(data)
        emb.append(feat.squeeze().cpu().numpy())
        r, mean, _ = vae(data)
        data_list.append(data.cpu().numpy())
        r_list.append(r.cpu().numpy())
        label_list.append(label.cpu().numpy())
        # higher score = better reconstructed = more likely in-distribution
        pred = -((r - data)**2).reshape(data.shape[0], -1).mean(1)
        x.append(pred.cpu().numpy())
        y.append(label.cpu().numpy())
    data_list = np.vstack(data_list)
    r_list = np.vstack(r_list)
    label_list = np.hstack(label_list)
    # assemble a 90 in-distribution / 10 novelty grid for visual inspection
    test_data = []
    test_r = []
    pos = (label_list == 0)
    test_data.append(data_list[pos][:90])
    test_r.append(r_list[pos][:90])
    pos = (label_list == 1)
    test_data.append(data_list[pos][:10])
    test_r.append(r_list[pos][:10])
    test_data = torch.from_numpy(np.vstack(test_data))
    test_r = torch.from_numpy(np.vstack(test_r))
    tv.utils.save_image(test_data * 0.5 + 0.5,
                        os.path.join(args.log_folder, 'test_origin.png'),
                        nrow=10)
    tv.utils.save_image(test_r * 0.5 + 0.5,
                        os.path.join(args.log_folder, 'test_reconstruct.png'),
                        nrow=10)
    x = np.hstack(x)
    y = np.hstack(y)
    emb = np.vstack(emb)
    print(x)
    print(y)
    print((x < 0.9).sum())
    print(x.shape, y.shape)
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y, x, pos_label=1)
    auc = sklearn.metrics.auc(fpr, tpr)
    print(auc)
    with open(os.path.join(args.log_folder, 'auc.txt'), 'w') as f:
        f.write('%.4f\n' % auc)
    ##########################################################
    # t-SNE plot of the first 1000 embeddings
    # (fix: tsne construction was commented out but still used)
    tsne = sklearn.manifold.TSNE(2)
    emb = emb[:1000]
    y = y[:1000]
    emb = tsne.fit_transform(emb)
    plt.clf()
    l_list = []
    for i in range(2):
        pos = y == i
        l_list.append(plt.scatter(emb[pos, 0], emb[pos, 1]))
    plt.legend(l_list, ['novelty data', 'train data'])
    save_path = os.path.join(args.log_folder, 'test_embedding.png')
    plt.savefig(save_path)