class Ganomaly(object):
    """GANomaly Class """

    # NOTE(review): a second class named `Ganomaly` (deriving from BaseModel)
    # appears later in this module and will shadow this definition at import
    # time — confirm which one callers are meant to use.

    @staticmethod
    def name():
        """Return name of the class. """
        return 'Ganomaly'

    def __init__(self, opt, dataloader=None):
        """Build networks, losses, buffers and optimizers.

        Args:
            opt: parsed options object (batchsize, isize, outf, name, lr,
                beta1, resume, isTrain, device, ... — schema defined by the
                project's option parser, not visible here).
            dataloader: dict-like with 'train' and 'test' DataLoaders.
        """
        super(Ganomaly, self).__init__()
        ##
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        # Any device string other than 'cpu' is mapped to the first CUDA GPU.
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # -- Discriminator attributes.
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None

        # -- Generator attributes.
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        # Optionally resume from checkpoints written by save_weights().
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            self.opt.iter = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            self.netg.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netG.pth'))['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        # print(self.netg)
        # print(self.netd)

        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = l2_loss

        ##
        # Initialize input tensors. These are reused across batches and
        # resized in place by set_input()/update_netd().
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = 1
        self.fake_label = 0

        ##
        # Setup optimizer
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def set_input(self, input):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        # no_grad suppresses autograd tracking of the in-place resize/copy
        # on the persistent buffers.
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            with torch.no_grad():
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2
        """
        ##
        # Feature Matching.
        self.netd.zero_grad()
        # --
        # Train with real
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)
        # --
        # Train with fake
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())
        # --
        # D loss is pure feature matching on the discriminator features.
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        # NOTE(review): err_d_real / err_d_fake are aliases of err_d, so the
        # re-init check in optimize() effectively tests err_d < 1e-5.
        self.err_d_real = self.err_d
        self.err_d_fake = self.err_d
        self.err_d.backward()
        self.optimizer_d.step()

    ##
    def reinitialize_netd(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def update_netg(self):
        """
        # ============================================================ #
        # (2) Update G network: log(D(G(x)))  + ||G(x) - x||           #
        # ============================================================ #
        """
        self.netg.zero_grad()
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        # Adversarial term: fool D on the (non-detached) fake from update_netd.
        self.out_g, _ = self.netd(self.fake)
        self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        self.err_g_l1l = self.l1l_criterion(
            self.fake, self.input)  # constrain x' to look like x
        # Latent consistency between encoder and re-encoder outputs.
        self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        self.err_g = self.err_g_bce * self.opt.w_bce + self.err_g_l1l * self.opt.w_rec + self.err_g_enc * self.opt.w_enc
        # retain_graph because self.fake's graph was already used in
        # update_netd's backward pass.
        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    ##
    def optimize(self):
        """ Optimize netD and netG  networks.
        """
        self.update_netd()
        self.update_netg()

        # If D loss is zero, then re-initialize netD
        if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
            self.reinitialize_netd()

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_d_real', self.err_d_real.item()),
                              ('err_d_fake', self.err_d_fake.item()),
                              ('err_g_bce', self.err_g_bce.item()),
                              ('err_g_l1l', self.err_g_l1l.item()),
                              ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """
        reals = self.input.data
        fakes = self.fake.data
        # Reconstruction of the fixed first batch (netg returns a tuple;
        # element 0 is the reconstructed image).
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_epoch(self):
        """ Train the model for one epoch.
        """
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes,
                                                           fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name(), self.epoch + 1, self.opt.niter))
        # self.visualizer.print_current_errors(self.epoch, errors)

    ##
    def train(self):
        """ Train the model
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_epoch()
            res = self.test()
            # Keep only the checkpoint with the best test AUC.
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    ##
    def test(self):
        """ Test GANomaly model.

        Args:
            dataloader ([type]): Dataloader for the test set

        Raises:
            IOError: Model weights not found.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netG weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            # print("   Testing model %s." % self.name())
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                # Anomaly score: mean squared distance between the two
                # latent codes, per sample.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                self.latent_i[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_i.reshape(
                                  error.size(0), self.opt.nz)
                self.latent_o[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_o.reshape(
                                  error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (mean over first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1] (min-max normalization).
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))
            # auc, eer = roc(self.gt_labels, self.an_scores)
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)',
                                        self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
class SSnovelty(object):
    """Self-supervised novelty detection model.

    Trains a generator (netg), a discriminator (netd) and a 4-way rotation
    classifier (netc) on rotation-augmented images; anomaly scores at test
    time combine classifier confidence and reconstruction distances.
    """

    @staticmethod
    def name():
        """Return name of the class. """
        return 'SSnovelty'

    def __init__(self, opt):
        """Build networks, losses, buffers and optimizers.

        Args:
            opt: parsed options object (batchsize, isize, nc, lr, lr_c,
                beta1, resume, isTrain, ... — schema defined by the
                project's option parser, not visible here).
        """
        super(SSnovelty, self).__init__()
        ##
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        # self.warmup = hyperparameters['model_specifics']['warmup']
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        # Any device string other than 'cpu' is mapped to the first CUDA GPU.
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # -- Discriminator attributes.
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        # self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None
        self.idx = 0
        self.opt.display = True

        # -- Generator attributes.
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # -- Misc attributes
        self.epoch = 0
        self.epoch1 = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)
        self.netc = Class(self.opt).to(self.device)

        ##
        # Optionally resume: only the classifier checkpoint is restored.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            # self.opt.iter = torch.load(os.path.join(self.opt.resume, 'netG.pth'))['epoch']
            # self.netg.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netG.pth'))['state_dict'])
            # self.netd.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netD.pth'))['state_dict'])
            self.netc.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'class.pth'))['state_dict'])
            print("\tDone.\n")

        # print(self.netg)
        # print(self.netd)

        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.mse_criterion = nn.MSELoss()
        self.l2l_criterion = l2_loss
        self.loss_func = torch.nn.CrossEntropyLoss()

        ##
        # Initialize input tensors (reused/resized in place across batches).
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.input_1 = torch.empty(size=(self.opt.batchsize, 3,
                                         self.opt.isize, self.opt.isize),
                                   dtype=torch.float32,
                                   device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label_r = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32,
                                   device=self.device)
        self.gt = torch.empty(size=(self.opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = 1
        self.fake_label = 0
        # NOTE(review): sigma_list looks like MMD-kernel bandwidths; it is
        # initialized here but not used anywhere in this visible class.
        base = 1.0
        sigma_list = [1, 2, 4, 8, 16]
        self.sigma_list = [sigma / base for sigma in sigma_list]

        ##
        # Setup optimizer
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.netc.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_c = optim.Adam(self.netc.parameters(),
                                          lr=self.opt.lr_c,
                                          betas=(self.opt.beta1, 0.999))

    def set_input(self, input):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        self.input.data.resize_(input[0].size()).copy_(input[0])
        self.gt.data.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            self.fixed_input.data.resize_(input[0].size()).copy_(input[0])

    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2
        """
        ##
        # Feature Matching.
        self.netd.zero_grad()
        # --
        # Train with real
        self.label.data.resize_(self.input.size(0)).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)
        # self.err_d_real = self.bce_criterion(self.out_d_real,self.label)
        # Train with fake
        self.label.data.resize_(self.input.size(0)).fill_(self.fake_label)
        # netg here returns (fake, latent) — two values, unlike the GANomaly
        # generator above which also returns latent_o.
        self.fake, self.latent_i, = self.netg(self.input)
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())
        # self.err_d_fake = self.bce_criterion(self.out_d_fake, self.label)
        # --
        # self.err_d = self.err_d_real + self.err_d_fake
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        # self.err_d = self.err_d_fake + self.err_d_l2
        # self.err_d_real = self.err_d
        # self.err_d_fake = self.err_d
        self.err_d.backward()
        self.optimizer_d.step()

    ##
    def reinitialize_netd(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    def updata_netc(self):
        """Update the rotation classifier netc.

        Loss = CE(netc(fake), rot_label) + CE(netc(real), rot_label)
             + l2(netc(real), netc(fake)).
        Uses self.img_real / self.label_real prepared by
        argument_image_rotation_plus(). (NOTE(review): method name is a typo
        for "update_netc"; kept as-is since callers use this spelling.)
        """
        self.netc.zero_grad()
        output_real = self.netc(self.img_real)
        # NOTE(review): self.fake is NOT detached here, but only
        # optimizer_c.step() is taken, so netg weights are unaffected.
        self.fake = self.netg(self.img_real)
        output_fake = self.netc(self.fake)
        self.err_c_fake = self.loss_func(output_fake, self.label_real)
        self.err_c_real = self.loss_func(output_real, self.label_real)
        self.err_c_dis = self.l2l_criterion(output_real, output_fake)
        self.err_c = self.err_c_fake + self.err_c_real + self.err_c_dis
        self.err_c.backward()
        self.optimizer_c.step()

    ##
    def trans_img(self, input):
        """Build the rotation-shuffled target map.

        For each group of 4 rotated views, each output slot receives the
        input at positions permuted by (2, 3, 0, 1) within the group, each
        additionally transformed by rotate_img_trans(·, 1).

        Args:
            input: batch of images; len(input) is assumed divisible by 4.

        Returns:
            Tensor of shape (len(input), nc, isize, isize) on self.device.
        """
        size = len(input)
        trans_map = torch.empty(size=(size, self.opt.nc, self.opt.isize,
                                      self.opt.isize),
                                dtype=torch.float32,
                                device=self.device)
        idx = int(size / 4)
        for i in range(idx):
            img = rotate_img_trans(input[i * 4 + 2], 1)
            trans_map[i * 4] = img
            img = rotate_img_trans(input[i * 4 + 3], 1)
            trans_map[i * 4 + 1] = img
            img = rotate_img_trans(input[i * 4], 1)
            trans_map[i * 4 + 2] = img
            img = rotate_img_trans(input[i * 4 + 1], 1)
            trans_map[i * 4 + 3] = img
        return trans_map

    def update_netg(self):
        """
        # ============================================================ #
        # (2) Update G network: log(D(G(x)))  + ||G(x) - x||           #
        # ============================================================ #
        """
        self.netg.zero_grad()
        # self.out_g, _ = self.netd(self.fake)
        # self.label.data.resize_(self.out_g.shape).fill_(self.real_label)
        # self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        self.fake = self.netg(self.img_real)
        # Target map is built from the detached fake on CPU, so err_g_r only
        # backpropagates through the first netg(fake) argument.
        self.img_trans = self.trans_img(self.fake.detach().cpu())
        self.err_g_r = self.mse_criterion(self.fake, self.img_trans)
        # self.err_g_l1l = self.mse_criterion(self.fake, self.img_real)  # constrain x' to look like x
        # self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        output_fake = self.netc(self.fake)
        output_real = self.netc(self.img_real)
        self.err_g_loss = self.l2l_criterion(output_fake, output_real)
        self.loss = self.loss_func(output_fake, self.label_real)
        # self.err_g = self.err_g_bce + self.err_g_l1l * self.opt.w_rec + (self.loss + self.err_g_loss) * self.opt.w_enc
        # self.err_g = self.err_g_bce + (loss + self.err_g_loss) * self.opt.w_enc
        # self.err_g = (self.loss + self.err_g_loss ) * self.opt.w_enc
        self.err_g = (self.err_g_r) * self.opt.w_rec + (
            self.loss + self.err_g_loss) * self.opt.w_enc
        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    ##
    def argument_image_rotation_plus(self, X):
        """Rotation augmentation: expand each image into its 4 rotations.

        Fills self.img_real (4x batch of rotated views) and self.label_real
        (rotation index 0-3 per view) via rotate_img(img, i).

        Args:
            X: batch of input images.
        """
        size = len(X)
        self.img_real = torch.empty(size=(size * 4, self.opt.nc,
                                          self.opt.isize, self.opt.isize),
                                    dtype=torch.float32,
                                    device=self.device)
        self.label_real = torch.empty(size=(size * 4, ),
                                      dtype=torch.long,
                                      device=self.device)
        for idx in range(size):
            img0 = X[idx]
            for i in range(4):
                [img, label] = rotate_img(img0, i)
                self.img_real[idx * 4 + i] = img
                self.label_real[idx * 4 + i] = label

    def optimize(self):
        """ Optimize netD and netG  networks.
        """
        # Prepare rotated views + rotation labels for this batch.
        self.argument_image_rotation_plus(self.input)
        self.input_img = self.input.to(self.device)
        self.updata_netc()
        # self.update_netd()
        self.update_netg()

        # If D loss is zero, then re-initialize netD
        # if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
        #     self.reinitialize_netd()

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
        ])
        return errors

    def get_errors_1(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([
            ('err_c_real', self.err_c_real.item()),
            ('err_c_fake', self.err_c_fake.item()),
            ('err_c', self.err_c.item()),
        ])
        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, trans]
        """
        reals = self.img_real.data
        fakes = self.fake.data
        trans = self.img_trans.data
        # fixed = self.netg(self.fixed_input)[0].data
        # fixed_input = self.fixed_input.data
        return reals, fakes, trans

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_step(self):
        """Train netc and netg for one epoch over self.train_loader."""
        self.netg.train()
        epoch_iter = 0
        for step, (x, y, z) in enumerate(self.train_loader):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize
            self.input = Variable(x)
            self.optimize()
            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, trans = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    trans)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes,
                                                           trans)
            # errors = self.get_errors()
            # if self.total_steps % self.opt.save_image_freq == 0:
            #     reals, fakes, fixed , fixed_input = self.get_current_images()
            #     self.visualizer.save_current_images(self.epoch, reals, fakes, fixed )
            #     if self.opt.display:
            #         self.visualizer.display_current_images(reals, fakes, fixed,fixed_input)
        # print('Epoch %d err_g %f err_d %f err_c %f' % (self.epoch, self.err_g.item(),self.err_d.item(),self.err_c.item()))
        print('Epoch %d  err_g %f err_c %f' %
              (self.epoch, self.err_g.item(), self.err_c.item()))
        # print('Epoch %d err_d_real %f err_d_fake %f ' % (self.epoch, self.err_d_real.item(), self.err_d_fake.item()))
        # print('Epoch %d err_g_bce %f err_g_loss %f err_g_l1 %f loss %f ' % (self.epoch, self.err_g_bce.item(),
        #       self.err_g_loss.item(),self.err_g_l1l.item(),self.loss.item()))
        # print(">> Training model %s. Epoch %d/%d" % (self.name(), self.epoch + 1, self.opt.niter))

    def train(self):
        """ Train the model
        """
        ##
        # TRAIN
        self.total_steps = 0
        # best_auc tracks [AUC_R, AUC_C_real, AUC_C_fake].
        best_auc = [0, 0, 0]

        # Train for niter epochs.
        print(">> Training model %s." % self.name())
        # cifa10Data builds CIFAR-10 one-class train/test splits — presumably
        # keyed by opt.normalclass; confirm against its definition.
        train_data, test_data = cifa10Data(self.opt.normalclass)
        self.train_loader = DataLoader(train_data,
                                       batch_size=self.opt.batchsize,
                                       shuffle=True,
                                       num_workers=0,
                                       pin_memory=True)
        self.test_loader = DataLoader(test_data,
                                      batch_size=self.opt.batchsize,
                                      shuffle=False,
                                      num_workers=0,
                                      pin_memory=True)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_step()
            # Evaluate every 20 epochs and append best scores to a log file.
            if self.epoch % 20 == 0:
                rec = self.test()
                if rec['AUC_R'] > best_auc[0]:
                    best_auc[0] = rec['AUC_R']
                if rec['AUC_C_real'] > best_auc[1]:
                    best_auc[1] = rec['AUC_C_real']
                if rec['AUC_C_fake'] > best_auc[2]:
                    best_auc[2] = rec['AUC_C_fake']
                # self.visualizer.print_current_performance(rec, best_auc)
                f = open('./output/testclass.txt',
                         'a',
                         encoding='utf-8-sig',
                         newline='\n')
                # f.write('rec: ' + str(rec['AUC_R'], ) + ', ' + str(rec['AUC_C_real'], ) + ',' + str(rec['AUC_C_fake'], ) + '\n')
                f.write('best' + str(best_auc) + '\n')
                f.close()
                # self.test_1()
            # self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())
        self.test()
        # self.test_1()

    ##
    def test(self):
        """Evaluate on the test set; returns OrderedDict of AUC metrics.

        Scores per sample: classifier cross-entropy on fake/real rotated
        views (pre / pre_real) and image reconstruction distance
        (relation_img). NOTE(review): buffer sizes 10000/40000 and the
        strides 64 and 4 hard-code a CIFAR-10 test set and batchsize 64 —
        confirm before reusing with other settings.
        """
        with torch.no_grad():
            self.opt.load_weights = True
            self.epoch1 = 1
            self.epoch2 = 200
            self.total_steps = 0
            epoch_iter = 0
            print('test')
            label = torch.zeros(size=(10000, ),
                                dtype=torch.long,
                                device=self.device)
            pre = torch.zeros(size=(10000, ),
                              dtype=torch.float32,
                              device=self.device)
            pre_real = torch.zeros(size=(10000, ),
                                   dtype=torch.float32,
                                   device=self.device)
            self.relation = torch.zeros(size=(10000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.distance = torch.zeros(size=(40000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.relation_img = torch.zeros(size=(10000, ),
                                            dtype=torch.float32,
                                            device=self.device)
            self.distance_img = torch.zeros(size=(40000, ),
                                            dtype=torch.float32,
                                            device=self.device)
            self.opt.phase = 'test'
            for i, (x, y, z) in enumerate(self.test_loader):
                self.input = Variable(x)
                self.label_r = Variable(z)
                self.total_steps += self.opt.batchsize
                self.argument_image_rotation_plus(self.input)
                self.label_r = self.label_r.to(self.device)
                # Negative log-softmax = per-class cross-entropy scores.
                classfiear_real_1 = self.netc(self.img_real)
                classfiear_real = F.softmax(classfiear_real_1, dim=1)
                prediction_real = -(torch.log(classfiear_real))
                self.fake = self.netg(self.img_real)
                classfiear_1 = self.netc(self.fake)
                classfiear = F.softmax(classfiear_1, dim=1)
                prediction = -(torch.log(classfiear))
                # aaaa = number of original images in this batch (views / 4).
                aaaa = (prediction.size(0) / 4)
                aaaa = int(aaaa)
                # prediction = prediction * (-1/4)
                label_z = torch.zeros(size=(aaaa, ),
                                      dtype=torch.long,
                                      device=self.device)
                pre_score = torch.zeros(size=(aaaa, ),
                                        dtype=prediction.dtype,
                                        device=self.device)
                pre_score_real = torch.zeros(size=(aaaa, ),
                                             dtype=prediction.dtype,
                                             device=self.device)
                self.img_trans = self.trans_img(self.fake.cpu())
                # Per-view mean squared pixel distance fake vs. real.
                distance_img = torch.mean(
                    torch.pow((self.fake - self.img_real), 2), -1)
                distance_img = torch.mean(torch.mean(distance_img, -1), -1)
                if self.total_steps % self.opt.save_image_freq == 0:
                    reals, fakes, trans = self.get_current_images()
                    self.visualizer.save_test_images(i, reals, fakes, trans)
                    if self.opt.display:
                        self.visualizer.display_test_images(
                            reals, fakes, trans)
                # distance = torch.mean(torch.pow((classfiear_1 - classfiear_real_1), 2), -1)
                # self.distance[i * self.opt.batchsize: i * self.opt.batchsize + distance.size(0)] = distance.reshape(
                #     distance.size(0))
                # NOTE(review): stride hard-codes 64 (= 4 views x 16?) —
                # confirm it matches 4 * batchsize / 4 bookkeeping.
                self.distance_img[i * 64:i * 64 +
                                  distance_img.size(0)] = distance_img.reshape(
                                      distance_img.size(0))
                for k in range(aaaa):
                    # label_z[k] = self.label_r[k * 4]
                    # Average CE of each rotated view against its own
                    # rotation class (0..3).
                    pre_score[k] = (prediction[k * 4, 0] +
                                    prediction[k * 4 + 1, 1] +
                                    prediction[k * 4 + 2, 2] +
                                    prediction[k * 4 + 3, 3]) / 4
                    pre_score_real[k] = (prediction_real[k * 4, 0] +
                                         prediction_real[k * 4 + 1, 1] +
                                         prediction_real[k * 4 + 2, 2] +
                                         prediction_real[k * 4 + 3, 3]) / 4
                label[i * self.opt.batchsize:i * self.opt.batchsize +
                      aaaa] = self.label_r
                pre[i * self.opt.batchsize:i * self.opt.batchsize +
                    aaaa] = pre_score
                pre_real[i * self.opt.batchsize:i * self.opt.batchsize +
                         aaaa] = pre_score_real
            # Keep one reconstruction distance per original image (view 0).
            for j in range(10000):
                # self.relation[j] = self.distance[j*4]
                self.relation_img[j] = self.distance_img[j * 4]
            # D = pre + self.relation * 0.2
            # D_real = pre_real + self.relation * 0.2
            #
            # mu = torch.mul(pre, self.relation)
            # mu_real = torch.mul(pre_real, self.relation)
            aaaa = self.relation_img.cpu().numpy()
            np.savetxt('./output/log.txt', aaaa)
            bbbb = label.cpu().numpy()
            np.savetxt('./output/label.txt', bbbb)
            # auc_mu_fake = evaluate(label, mu, metric=self.opt.metric)
            # auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric)
            # auc_d_fake = evaluate(label, D, metric=self.opt.metric)
            # auc_d_real = evaluate(label, D_real, metric=self.opt.metric)
            auc_c_fake = evaluate(label, pre, metric=self.opt.metric)
            auc_c_real = evaluate(label, pre_real, metric=self.opt.metric)
            # auc_r = evaluate(label, self.relation, metric=self.opt.metric)
            auc_r_img = evaluate(label,
                                 self.relation_img,
                                 metric=self.opt.metric)
            performance = OrderedDict([('AUC_R', auc_r_img),
                                       ('AUC_C_real', auc_c_real),
                                       ('AUC_C_fake', auc_c_fake)])
            print('test done')
            return performance
            # print('Train mul_real ROC AUC Score: %f mu_fake: %f' % (auc_mu_real, auc_mu_fake))
            # print('Train add_real ROC AUC Score: %f add_fake: %f' % (auc_d_real, auc_d_fake))

    def test_1(self):
        """Alternative evaluation: the loader is assumed to yield already
        rotation-expanded batches (groups of 4 consecutive views).

        NOTE(review): stride 16 and buffer size 10000 hard-code a specific
        batchsize/test-set combination — confirm before reuse.
        """
        with torch.no_grad():
            self.total_steps_test = 0
            epoch_iter = 0
            print('test')
            label = torch.zeros(size=(10000, ),
                                dtype=torch.long,
                                device=self.device)
            pre = torch.zeros(size=(10000, ),
                              dtype=torch.float32,
                              device=self.device)
            pre_real = torch.zeros(size=(10000, ),
                                   dtype=torch.float32,
                                   device=self.device)
            self.relation = torch.zeros(size=(10000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.relation_img = torch.zeros(size=(10000, ),
                                            dtype=torch.float32,
                                            device=self.device)
            self.classifiear = torch.zeros(size=(10000, 4),
                                           dtype=torch.float32,
                                           device=self.device)
            self.opt.phase = 'test'
            for i, (x, y, z) in enumerate(self.test_loader):
                self.input = Variable(x)
                self.label_rrr = Variable(z)
                self.input = self.input.to(self.device)
                self.label_rrr = self.label_rrr.to(self.device)
                # De-interleave the 4 rotation views into separate batches.
                size = int(self.input.size(0) / 4)
                input_1 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_2 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_3 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_4 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                classfiear_real_1 = self.netc(self.input)
                classfiear_real = F.softmax(classfiear_real_1, dim=1)
                prediction_real = -(torch.log(classfiear_real))
                for j in range(size):
                    input_1[j] = self.input[j * 4]
                    input_2[j] = self.input[j * 4 + 1]
                    input_3[j] = self.input[j * 4 + 2]
                    input_4[j] = self.input[j * 4 + 3]
                output_1 = self.netg(input_1)
                output_2 = self.netg(input_2)
                output_3 = self.netg(input_3)
                output_4 = self.netg(input_4)
                classifiear_real = self.netc(input_1)
                classfiear_11 = self.netc(output_1)
                classfiear_21 = self.netc(output_2)
                classfiear_31 = self.netc(output_3)
                classfiear_41 = self.netc(output_4)
                classfiear_1 = F.softmax(classfiear_11, dim=1)
                classfiear_2 = F.softmax(classfiear_21, dim=1)
                classfiear_3 = F.softmax(classfiear_31, dim=1)
                classfiear_4 = F.softmax(classfiear_41, dim=1)
                prediction_1 = -(torch.log(classfiear_1))
                prediction_2 = -(torch.log(classfiear_2))
                prediction_3 = -(torch.log(classfiear_3))
                prediction_4 = -(torch.log(classfiear_4))
                aaaa = prediction_1.size(0)
                self.classifiear[i * 16:i * 16 + aaaa] = classfiear_11
                # prediction = prediction * (-1/4)
                label_z = torch.zeros(size=(aaaa, ),
                                      dtype=torch.long,
                                      device=self.device)
                pre_score = torch.zeros(size=(aaaa, ),
                                        dtype=prediction_1.dtype,
                                        device=self.device)
                pre_score_real = torch.zeros(size=(aaaa, ),
                                             dtype=prediction_1.dtype,
                                             device=self.device)
                distance_img = torch.mean(torch.pow((output_1 - input_1), 2),
                                          -1)
                distance_img = torch.mean(torch.mean(distance_img, -1), -1)
                # Logit-space distance real vs. reconstructed (view 0).
                distance = torch.mean(
                    torch.pow((classifiear_real - classfiear_11), 2), -1)
                self.relation[i * 16:i * 16 +
                              distance.size(0)] = distance.reshape(
                                  distance.size(0))
                self.relation_img[i * 16:i * 16 +
                                  distance.size(0)] = distance_img.reshape(
                                      distance.size(0))
                for k in range(aaaa):
                    label_z[k] = self.label_rrr[k * 4]
                    pre_score[k] = (prediction_1[k, 0] + prediction_2[k, 1] +
                                    prediction_3[k, 2] +
                                    prediction_4[k, 3]) / 4
                    pre_score_real[k] = (prediction_real[k * 4, 0] +
                                         prediction_real[k * 4 + 1, 1] +
                                         prediction_real[k * 4 + 2, 2] +
                                         prediction_real[k * 4 + 3, 3]) / 4
                label[i * 16:i * 16 + aaaa] = label_z
                pre[i * 16:i * 16 + aaaa] = pre_score
                pre_real[i * 16:i * 16 + aaaa] = pre_score_real
            # Combined scores: additive and multiplicative fusion of
            # classifier score and logit distance.
            D = pre + self.relation * 0.2
            D_real = pre_real + self.relation * 0.2
            aaaa = self.classifiear.cpu().numpy()
            np.savetxt('./output/log.txt', aaaa)
            bbbb = label.cpu().numpy()
            np.savetxt('./output/label.txt', bbbb)
            mu = torch.mul(pre, self.relation)
            mu_real = torch.mul(pre_real, self.relation)
            auc_mu_fake = evaluate(label, mu, metric=self.opt.metric)
            auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric)
            auc_d_fake = evaluate(label, D, metric=self.opt.metric)
            auc_d_real = evaluate(label, D_real, metric=self.opt.metric)
            auc_c_fake = evaluate(label, pre, metric=self.opt.metric)
            auc_c_real = evaluate(label, pre_real, metric=self.opt.metric)
            auc_r = evaluate(label, self.relation, metric=self.opt.metric)
            auc_r_img = evaluate(label,
                                 self.relation_img,
                                 metric=self.opt.metric)
            print('Train mul_real ROC AUC Score: %f mu_fake: %f' %
                  (auc_mu_real, auc_mu_fake))
            print('Train add_real ROC AUC Score: %f add_fake: %f' %
                  (auc_d_real, auc_d_fake))
            print('Train class_real ROC AUC Score: %f class_fake: %f' %
                  (auc_c_real, auc_c_fake))
            print('Train recon ROC AUC Score: %f recon_img:%f' %
                  (auc_r, auc_r_img))
            print('test done')
class Ganomaly(BaseModel):
    """GANomaly model: an adversarially trained encoder-decoder-encoder
    generator (netG) plus a discriminator (netD) for anomaly detection.

    The generator loss combines feature-matching (adv), reconstruction (con)
    and latent-consistency (enc) terms weighted by opt.w_adv / w_con / w_enc.
    """

    @property
    def name(self):
        """Model identifier used for output directories and logging."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, I/O tensors and (in train mode) optimizers.

        Args:
            opt: parsed options namespace (batchsize, isize, lr, beta1, ...).
            dataloader: dataloader(s) passed through to BaseModel.
        """
        super(Ganomaly, self).__init__(opt, dataloader)

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        # Optionally resume from pre-trained weights. Each checkpoint is
        # loaded ONCE (the original read netG.pth from disk twice: once for
        # 'epoch', again for 'state_dict'). map_location keeps GPU-saved
        # checkpoints loadable on a CPU-only device.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            netg_ckpt = torch.load(os.path.join(self.opt.resume, 'netG.pth'),
                                   map_location=self.device)
            netd_ckpt = torch.load(os.path.join(self.opt.resume, 'netD.pth'),
                                   map_location=self.device)
            self.opt.iter = netg_ckpt['epoch']
            self.netg.load_state_dict(netg_ckpt['state_dict'])
            self.netd.load_state_dict(netd_ckpt['state_dict'])
            print("\tDone.\n")

        # Loss functions: feature-matching (l2), reconstruction (l1),
        # latent consistency (l2) and BCE for the discriminator.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        ##
        # Pre-allocated input / target tensors (filled in-place per batch).
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32, device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32, device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32, device=self.device)
        # BCE targets for the discriminator: ones for real, zeros for fake.
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32, device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32, device=self.device)

        ##
        # Setup optimizers (training mode only).
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def forward_g(self):
        """Forward propagate through netG.

        Produces the reconstruction (fake) and both latent codes: the
        encoder's (latent_i) and the re-encoder's (latent_o).
        """
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    ##
    def forward_d(self):
        """Forward propagate through netD for real and (detached) fake images."""
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    ##
    def backward_g(self):
        """Backpropagate through netG.

        err_g = w_adv * |f(x) - f(G(x))|_2 + w_con * |x - G(x)|_1
              + w_enc * |z_o - z_i|_2
        """
        self.err_g_adv = self.l_adv(self.netd(self.input)[1],
                                    self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
                     self.err_g_con * self.opt.w_con + \
                     self.err_g_enc * self.opt.w_enc
        # retain_graph: the subsequent discriminator backward pass reuses
        # parts of this computation graph within the same iteration.
        self.err_g.backward(retain_graph=True)

    ##
    def backward_d(self):
        """Backpropagate through netD using BCE against real/fake targets."""
        # Real - Fake Loss
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)

        # NetD Loss & Backward-Pass
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    ##
    def reinit_d(self):
        """Re-initialize netD's weights (used when its loss collapses)."""
        self.netd.apply(weights_init)
        print('   Reloading net d')

    def optimize_params(self):
        """One training step: forward G and D, then update each in turn.

        netD is re-initialized if its loss has effectively vanished, which
        indicates the discriminator has collapsed.
        """
        # Forward-pass
        self.forward_g()
        self.forward_d()

        # Backward-pass
        # netg
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # netd
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()
        if self.err_d.item() < 1e-5:
            self.reinit_d()
class Ganomaly(BaseModel):
    """GANomaly model with optional latent-space classifiers.

    Extends the generator (netG) / discriminator (netD) pair with two
    classifiers, netc_i and netc_o, that operate on the encoder and
    re-encoder latent vectors reshaped into square 1-channel images.
    """

    @property
    def name(self):
        """Model identifier used for output directories and logging."""
        return 'Ganomaly'

    def __init__(self, opt, dataloader):
        """Build networks, losses, I/O tensors and (in train mode) optimizers.

        Args:
            opt: parsed options namespace (batchsize, isize, nz, classifier,
                 resume, z_resume, lr, beta1, ...).
            dataloader: dataloader(s) passed through to BaseModel.
        """
        super(Ganomaly, self).__init__(opt, dataloader)

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)
        if self.opt.classifier:
            # Latent classifiers: one for the encoder latent (i), one for
            # the re-encoder latent (o).
            self.netc_i = NetC(self.opt).to(self.device)
            self.netc_o = NetC(self.opt).to(self.device)
            self.netc_i.apply(weights_init)
            self.netc_o.apply(weights_init)

        ##
        # Optionally resume G/D weights. Each checkpoint is loaded ONCE
        # (the original read netG.pth twice: for 'epoch' and 'state_dict').
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            netg_ckpt = torch.load(os.path.join(self.opt.resume, 'netG.pth'),
                                   map_location=self.device)
            netd_ckpt = torch.load(os.path.join(self.opt.resume, 'netD.pth'),
                                   map_location=self.device)
            self.opt.iter = netg_ckpt['epoch']
            self.netg.load_state_dict(netg_ckpt['state_dict'])
            self.netd.load_state_dict(netd_ckpt['state_dict'])
            print("\tDone.\n")
        if self.opt.z_resume != '':
            # NOTE(review): this branch assumes opt.classifier is also set;
            # otherwise netc_i / netc_o do not exist here — confirm callers.
            print("\nLoading pre-trained z_networks.")
            netc_i_ckpt = torch.load(
                os.path.join(self.opt.z_resume, 'netC_i.pth'),
                map_location=self.device)
            netc_o_ckpt = torch.load(
                os.path.join(self.opt.z_resume, 'netC_o.pth'),
                map_location=self.device)
            self.opt.iter = netc_i_ckpt['epoch']
            self.netc_i.load_state_dict(netc_i_ckpt['state_dict'])
            self.netc_o.load_state_dict(netc_o_ckpt['state_dict'])
            print("\tDone.\n")

        # Loss functions: feature-matching (l2), reconstruction (l1),
        # latent consistency (l2) and BCE for netD and the classifiers.
        self.l_adv = l2_loss
        self.l_con = nn.L1Loss()
        self.l_enc = l2_loss
        self.l_bce = nn.BCELoss()

        ##
        # Pre-allocated input / target tensors (filled in-place per batch).
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32, device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32, device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32, device=self.device)
        # BCE targets for the discriminator: ones for real, zeros for fake.
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32, device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32, device=self.device)

        ##
        # Classifier input tensors: the nz-dim latent reshaped to a square
        # 1-channel image.
        # NOTE(review): i_input uses self.sqrtnz (presumably set by
        # BaseModel) while o_input recomputes int(nz ** 0.5); these should
        # agree — confirm.
        self.i_input = torch.empty(size=(self.opt.batchsize, 1, self.sqrtnz,
                                         self.sqrtnz),
                                   dtype=torch.float32, device=self.device)
        self.o_input = torch.empty(size=(self.opt.batchsize, 1,
                                         int(self.opt.nz ** 0.5),
                                         int(self.opt.nz ** 0.5)),
                                   dtype=torch.float32, device=self.device)
        self.i_gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                                device=self.device)
        self.o_gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long,
                                device=self.device)
        self.i_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32, device=self.device)
        self.o_label = torch.empty(size=(self.opt.batchsize, ),
                                   dtype=torch.float32, device=self.device)
        # NOTE(review): despite the "real" in the names, these targets are
        # all zeros — confirm the intended classifier objective.
        self.i_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)
        self.o_real_label = torch.zeros(size=(self.opt.batchsize, ),
                                        dtype=torch.float32,
                                        device=self.device)

        ##
        # Setup optimizers (training mode only).
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            if self.opt.classifier:
                self.netc_i.train()
                self.netc_o.train()
                self.optimizer_i = optim.Adam(self.netc_i.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))
                self.optimizer_o = optim.Adam(self.netc_o.parameters(),
                                              lr=self.opt.lr,
                                              betas=(self.opt.beta1, 0.999))

    ##
    def forward_g(self):
        """Forward propagate through netG.

        Produces the reconstruction (fake) and both latent codes: the
        encoder's (latent_i) and the re-encoder's (latent_o).
        """
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)

    ##
    def forward_d(self):
        """Forward propagate through netD for real and (detached) fake images."""
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake.detach())

    ##
    def backward_g(self):
        """Backpropagate through netG.

        err_g = w_adv * |f(x) - f(G(x))|_2 + w_con * |x - G(x)|_1
              + w_enc * |z_o - z_i|_2
        """
        self.err_g_adv = self.l_adv(self.netd(self.input)[1],
                                    self.netd(self.fake)[1])
        self.err_g_con = self.l_con(self.fake, self.input)
        self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
        self.err_g = self.err_g_adv * self.opt.w_adv + \
                     self.err_g_con * self.opt.w_con + \
                     self.err_g_enc * self.opt.w_enc
        # retain_graph: the subsequent discriminator backward pass reuses
        # parts of this computation graph within the same iteration.
        self.err_g.backward(retain_graph=True)

    ##
    def backward_d(self):
        """Backpropagate through netD using BCE against real/fake targets."""
        # Real - Fake Loss
        self.err_d_real = self.l_bce(self.pred_real, self.real_label)
        self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)

        # NetD Loss & Backward-Pass
        self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
        self.err_d.backward()

    ##
    def reinit_d(self):
        """Re-initialize netD's weights (used when its loss collapses)."""
        self.netd.apply(weights_init)
        if (self.opt.strengthen != 1):
            print('   Reloading net d')

    def optimize_params(self):
        """One training step for G and D; re-initializes a collapsed netD."""
        # Forward-pass
        self.forward_g()
        self.forward_d()

        # Backward-pass
        # netg
        self.optimizer_g.zero_grad()
        self.backward_g()
        self.optimizer_g.step()

        # netd
        self.optimizer_d.zero_grad()
        self.backward_d()
        self.optimizer_d.step()
        if self.err_d.item() < 1e-5:
            self.reinit_d()

    ##
    def save_weights_z(self, epoch):
        """Save netC_i and netC_o weights for the current epoch.

        Args:
            epoch (int): Current epoch number (stored as epoch + 1).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        # BUG FIX: the attributes are self.netc_i / self.netc_o (lowercase
        # 'c' as created in __init__); the original referenced non-existent
        # self.netC_i / self.netC_o and raised AttributeError.
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_i.state_dict()
        }, '%s/netC_i.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_o.state_dict()
        }, '%s/netC_o.pth' % (weight_dir))

    def forward_i(self):
        """Forward propagate through netc_i on the encoder-latent input."""
        self.pred_abn_i = self.netc_i(self.i_input)

    def forward_o(self):
        """Forward propagate through netc_o on the re-encoder-latent input."""
        self.pred_abn_o = self.netc_o(self.o_input)

    def backward_i(self):
        """Backpropagate through netc_i (BCE against i_real_label)."""
        self.err_i = self.l_bce(self.pred_abn_i, self.i_real_label)
        self.err_i.backward()

    def backward_o(self):
        """Backpropagate through netc_o (BCE against o_real_label)."""
        self.err_o = self.l_bce(self.pred_abn_o, self.o_real_label)
        self.err_o.backward()

    def reinit_i(self):
        """Re-initialize netc_i's weights (used when its loss collapses)."""
        self.netc_i.apply(weights_init)
        if (self.opt.strengthen != 1):
            print('   Reloading net i')

    def reinit_o(self):
        """Re-initialize netc_o's weights (used when its loss collapses)."""
        self.netc_o.apply(weights_init)
        if (self.opt.strengthen != 1):
            print('   Reloading net o')

    def z_optimize_params(self, net):
        """One training step for the selected latent classifier.

        Args:
            net (str): 'i' to update netc_i, 'o' to update netc_o.
        """
        if net == 'i':
            # Forward-pass
            self.forward_i()
            # Backward-pass
            self.optimizer_i.zero_grad()
            self.backward_i()
            self.optimizer_i.step()
            if self.err_i.item() < 1e-5:
                self.reinit_i()
        if net == 'o':
            # Forward-pass
            self.forward_o()
            # Backward-pass
            self.optimizer_o.zero_grad()
            self.backward_o()
            self.optimizer_o.step()
            if self.err_o.item() < 1e-5:
                self.reinit_o()