class BaseModel():
    """ Base Model for ganomaly

    Shared training / evaluation loops for GANomaly-style models. This class
    reads several attributes it never defines itself -- ``netg``, ``netd``,
    ``name``, ``optimize_params``, the ``input`` / ``label`` / ``fixed_input``
    buffers, ``total_steps``, ``epoch`` and the ``err_*`` losses read by
    ``get_errors`` -- so it is only usable through a subclass that sets them.
    """

    # Class-name -> label-index maps for datasets whose abnormal class is
    # given by name. The historical misspelling 'sneacker' is kept so existing
    # command lines keep working; 'sneaker' is accepted as an alias.
    _FASHIONMNIST_CLASSES = {
        'tshirt': 0, 'trouser': 1, 'pullover': 2, 'dress': 3, 'coat': 4,
        'sandal': 5, 'shirt': 6, 'sneacker': 7, 'sneaker': 7, 'bag': 8,
        'boot': 9,
    }
    _CIFAR10_CLASSES = {
        'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5,
        'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9,
    }

    def __init__(self, opt, dataloader):
        """Seed RNGs and store options, dataloaders, output dirs and device.

        Args:
            opt: Options namespace (reads manualseed, outf, name, device).
            dataloader: Mapping holding the 'train' and 'test' dataloaders.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def _abnormal_index(self):
        """Return the integer label of ``opt.abnormal_class`` for ``opt.dataset``.

        'mnist' and 'svhn' take the class as an index; 'fashionmnist' takes a
        class name; every other dataset is mapped with the CIFAR-10 names.

        Raises:
            KeyError: The class name is unknown for the dataset.
        """
        dataset = self.opt.dataset
        if dataset in ('mnist', 'svhn'):
            return int(self.opt.abnormal_class)
        if dataset == 'fashionmnist':
            return int(self._FASHIONMNIST_CLASSES[self.opt.abnormal_class])
        return int(self._CIFAR10_CLASSES[self.opt.abnormal_class])

    def set_input(self, input: torch.Tensor):
        """ Set input and ground truth

        Args:
            input: ``(images, labels)`` pair for batch i, ``labels`` holding
                the original class indices.

        Side effects:
            * ``self.input`` receives a copy of the images.
            * ``self.label`` receives a copy of the original class labels
              (it was previously only resized, leaving garbage values that
              ``test_final`` reported as ``gt_labels_original``).
            * ``self.gt`` becomes a float tensor: 1 where the label equals the
              configured abnormal class, 0 elsewhere.
            * ``self.fixed_input`` is filled when ``total_steps`` equals
              ``opt.batchsize`` (i.e. the first batch of a counting run).
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.label.resize_(input[1].size()).copy_(input[1])

            labels = input[1].cpu().numpy()
            abn_idx = torch.from_numpy(
                np.where(labels == self._abnormal_index())[0])
            gt = torch.zeros(input[1].size())
            gt[abn_idx] = 1
            self.gt = gt

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def seed(self, seed_value):
        """ Seed python, numpy and torch (CPU and all CUDA devices) RNGs.

        Arguments:
            seed_value {int} -- Seed to use; -1 leaves all RNGs untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Loss values as Python floats. The ``err_*``
            attributes are not set in this class -- presumably by the
            subclass's ``optimize_params``.
        """
        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_g_adv', self.err_g_adv.item()),
                              ('err_g_con', self.err_g_con.item()),
                              ('err_g_enc', self.err_g_enc.item())])
        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]: the current input batch, the current
            reconstruction and netG's reconstruction of ``fixed_input``.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Files go to ``<outf>/<name>/test/abnormal<class>/seed<manualseed>/``.
        NOTE(review): ``test`` / ``test_final`` load from a different path
        (``./output/<name>/<dataset>/train/weights``) -- confirm intended.

        Args:
            epoch ([int]): Current epoch number (stored as ``epoch + 1``).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'test',
                                  'abnormal' + str(self.opt.abnormal_class),
                                  'seed' + str(self.opt.manualseed))
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.

        Calls ``optimize_params`` on every training batch and periodically
        plots errors / saves images according to ``print_freq`` and
        ``save_image_freq``.
        """
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs epochs ``opt.iter`` .. ``opt.niter - 1``, evaluating after each
        one and tracking the best AUC; saves weights once training ends.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)
        self.save_weights(self.epoch)

    ##
    def _load_pretrained_netg(self):
        """Load netG weights from ``./output/<name>/<dataset>/train/weights``.

        Raises:
            IOError: The weight file could not be read. (The check used to
                wrap ``load_state_dict`` instead of ``torch.load``, so the
                message was never produced.)
        """
        path = "./output/{}/{}/train/weights/netG.pth".format(
            self.name.lower(), self.opt.dataset)
        try:
            pretrained_dict = torch.load(
                path, map_location=self.device)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        self.netg.load_state_dict(pretrained_dict)
        print(' Loaded weights.')

    def _allocate_score_buffers(self, split):
        """Allocate zeroed per-sample score/label/latent buffers for ``split``."""
        size = len(self.dataloader[split].dataset)
        self.an_scores = torch.zeros(size=(size, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.gt_labels = torch.zeros(size=(size, ),
                                     dtype=torch.long,
                                     device=self.device)
        self.latent_i = torch.zeros(size=(size, self.opt.nz),
                                    dtype=torch.float32,
                                    device=self.device)
        self.latent_o = torch.zeros(size=(size, self.opt.nz),
                                    dtype=torch.float32,
                                    device=self.device)

    def _store_batch(self, i, error, latent_i, latent_o):
        """Write batch ``i``'s scores, labels and latents into the buffers.

        Rows are addressed as ``i * batchsize`` onward, so a short final
        batch only fills as many rows as it has samples.
        """
        count = error.size(0)
        start = i * self.opt.batchsize
        end = start + count
        self.an_scores[start:end] = error.reshape(count)
        self.gt_labels[start:end] = self.gt.reshape(count)
        self.latent_i[start:end, :] = latent_i.reshape(count, self.opt.nz)
        self.latent_o[start:end, :] = latent_o.reshape(count, self.opt.nz)

    def _save_test_images(self, i):
        """Save the current real/fake batch as EPS files under <outf>/<name>/test/images."""
        dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
        if not os.path.isdir(dst):
            os.makedirs(dst)
        real, fake, _ = self.get_current_images()
        vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1),
                          normalize=True)
        vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1),
                          normalize=True)

    @staticmethod
    def _min_max_scale(scores):
        """Scale ``scores`` linearly into [0, 1].

        Returns ``scores`` unchanged when empty and an all-zero tensor when
        every score is equal (plain division produced NaNs there).
        """
        if scores.numel() == 0:
            return scores
        low = torch.min(scores)
        span = torch.max(scores) - low
        if span.item() == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Scores each test sample by the mean squared distance between netG's
        ``latent_i`` and ``latent_o`` outputs, min-max scales the scores and
        evaluates them against the binary ground truth.

        Returns:
            OrderedDict with 'Avg Run Time (ms/batch)' and 'AUC'.

        Raises:
            IOError: Model weights not found (when ``opt.load_weights``).
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                self._load_pretrained_netg()

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self._allocate_score_buffers('test')

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                self._store_batch(i, error, latent_i, latent_o)
                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    self._save_test_images(i)

            # Measure inference time (mean over at most the first 100 batches).
            self.times = np.mean(np.array(self.times)[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance

    def train_final(self):
        """ Score the training split with netG.

        Fills ``an_scores`` (min-max scaled), ``gt_labels``, ``latent_i`` and
        ``latent_o`` for every training sample.

        Returns:
            Tensor ``latent_i - latent_o`` with one row per training sample.
        """
        with torch.no_grad():
            # Create big error tensor for the train set.
            self._allocate_score_buffers('train')

            self.times = []
            self.total_steps = 0
            for i, data in enumerate(self.dataloader['train'], 0):
                self.total_steps += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                self._store_batch(i, error, latent_i, latent_o)
                self.times.append(time_o - time_i)

            # Measure inference time.
            self.times = np.mean(np.array(self.times)[:100] * 1000)
            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            recon_1 = self.latent_i - self.latent_o
            return recon_1

    def _save_first_batch_grids(self):
        """Save 7x7 grids of the current reconstructions and inputs as PNGs.

        Written under ``<outf>/<dataset>/test/OCSVM/abnormal<class>/``.
        """
        out_dir = os.path.join(self.opt.outf, self.opt.dataset, 'test',
                               'OCSVM',
                               'abnormal' + str(self.opt.abnormal_class))
        print(out_dir)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        viz.viz_batch_im(np.squeeze(np.array(self.fake.cpu())),
                         grid_size=[7, 7],
                         save_path=os.path.join(
                             out_dir, 'reconstructed' +
                             str(self.opt.abnormal_class) + '.png'),
                         gap=0,
                         gap_color=0,
                         shuffle=False)
        viz.viz_batch_im(np.squeeze(np.array(self.input.cpu())),
                         grid_size=[7, 7],
                         save_path=os.path.join(
                             out_dir,
                             'img' + str(self.opt.abnormal_class) + '.png'),
                         gap=0,
                         gap_color=0,
                         shuffle=False)

    def test_final(self):
        """ Final evaluation of the model on the test split.

        Like ``test`` but also records the original class labels, saves image
        grids of the first batch and evaluates with ``evaluate_final``.

        Returns:
            (recon_1, y_true, y_true_original, auroc_value, auprc_value) where
            ``recon_1 = latent_i - latent_o`` per test sample, ``y_true`` the
            binary labels and ``y_true_original`` the original class labels.

        Raises:
            IOError: Model weights not found (when ``opt.load_weights``).
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                self._load_pretrained_netg()

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self._allocate_score_buffers('test')
            self.gt_labels_original = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                                  dtype=torch.long,
                                                  device=self.device)

            self.times = []
            self.total_steps = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)
                if i == 0:
                    self._save_first_batch_grids()

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                self._store_batch(i, error, latent_i, latent_o)
                start = i * self.opt.batchsize
                self.gt_labels_original[start:start + error.size(0)] = \
                    self.label.reshape(error.size(0))
                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    self._save_test_images(i)

            # Measure inference time.
            self.times = np.mean(np.array(self.times)[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auroc_value, auprc_value, _ = evaluate_final(
                self.gt_labels.cpu().detach(), self.an_scores.cpu().detach())

            recon_1 = self.latent_i - self.latent_o
            y_true = self.gt_labels
            y_true_original = self.gt_labels_original
            return recon_1, y_true, y_true_original, auroc_value, auprc_value
class Ganomaly(object):
    """GANomaly Class

    Builds NetG / NetD, trains them with a feature-matching discriminator loss
    and a weighted BCE + L1 + latent-L2 generator loss, and scores test
    samples by the squared distance between netG's two latent outputs.
    """

    @staticmethod
    def name():
        """Return name of the class.
        """
        return 'Ganomaly'

    def __init__(self, opt, dataloader=None):
        """Create networks, losses, input buffers and (if training) optimizers.

        Args:
            opt: Options namespace.
            dataloader: Mapping holding the 'train' and 'test' dataloaders.
        """
        super(Ganomaly, self).__init__()
        ##
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # -- Discriminator attributes.
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None

        # -- Generator attributes.
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            # netG.pth carries both the epoch and the weights: read it once.
            # map_location lets GPU-saved checkpoints resume on CPU.
            netg_ckpt = torch.load(os.path.join(self.opt.resume, 'netG.pth'),
                                   map_location=self.device)
            self.opt.iter = netg_ckpt['epoch']
            self.netg.load_state_dict(netg_ckpt['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume, 'netD.pth'),
                           map_location=self.device)['state_dict'])
            print("\tDone.\n")

        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = l2_loss

        ##
        # Initialize input tensors.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = 1
        self.fake_label = 0

        ##
        # Setup optimizer
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def set_input(self, input):
        """ Set input and ground truth

        Args:
            input: ``(images, labels)`` pair for batch i; both are copied into
                ``self.input`` / ``self.gt``. ``self.fixed_input`` is filled
                when ``total_steps`` equals ``opt.batchsize``.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            with torch.no_grad():
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2

        Also runs netG forward, setting ``fake``, ``latent_i`` and
        ``latent_o`` for ``update_netg``.
        """
        # The final batch of an epoch may be smaller than opt.batchsize.
        batch_size = self.input.size(0)
        ##
        # Feature Matching.
        self.netd.zero_grad()
        # --
        # Train with real
        with torch.no_grad():
            self.label.resize_(batch_size).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)
        # --
        # Train with fake
        with torch.no_grad():
            self.label.resize_(batch_size).fill_(self.fake_label)
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())
        # --
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        # Feature matching yields one loss; real/fake aliases feed get_errors
        # and the re-initialisation check in optimize().
        self.err_d_real = self.err_d
        self.err_d_fake = self.err_d
        self.err_d.backward()
        self.optimizer_d.step()

    ##
    def reinitialize_netd(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def update_netg(self):
        """
        # ============================================================ #
        # (2) Update G network: log(D(G(x))) + ||G(x) - x||            #
        # ============================================================ #

        err_g = w_bce * BCE(D(fake), real) + w_rec * L1(fake, input)
                + w_enc * l2(latent_o, latent_i)
        """
        self.netg.zero_grad()
        with torch.no_grad():
            # Size the BCE target to the actual batch: using opt.batchsize
            # made BCELoss fail on a shorter final batch.
            self.label.resize_(self.input.size(0)).fill_(self.real_label)
        self.out_g, _ = self.netd(self.fake)

        self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        self.err_g_l1l = self.l1l_criterion(
            self.fake, self.input)  # constrain x' to look like x
        self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        self.err_g = (self.err_g_bce * self.opt.w_bce +
                      self.err_g_l1l * self.opt.w_rec +
                      self.err_g_enc * self.opt.w_enc)

        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    ##
    def optimize(self):
        """ Optimize netD and netG networks.

        Re-initializes netD when its loss collapses below 1e-5.
        """
        self.update_netd()
        self.update_netg()

        # If D loss is zero, then re-initialize netD
        if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
            self.reinitialize_netd()

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors as Python floats.
        """
        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_d_real', self.err_d_real.item()),
                              ('err_d_fake', self.err_d_fake.item()),
                              ('err_g_bce', self.err_g_bce.item()),
                              ('err_g_l1l', self.err_g_l1l.item()),
                              ('err_g_enc', self.err_g_enc.item())])
        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Files go to ``<outf>/<name>/train/weights/{netG,netD}.pth``.

        Args:
            epoch ([int]): Current epoch number (stored as ``epoch + 1``).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_epoch(self):
        """ Train the model for one epoch.
        """
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name(), self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Evaluates after every epoch and saves weights whenever the AUC
        improves on the best seen so far.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    ##
    def _load_pretrained_netg(self):
        """Load netG weights from ``./output/<name>/<dataset>/train/weights``.

        Raises:
            IOError: The weight file could not be read. (The check used to
                wrap ``load_state_dict`` instead of ``torch.load``.)
        """
        path = "./output/{}/{}/train/weights/netG.pth".format(
            self.name().lower(), self.opt.dataset)
        try:
            pretrained_dict = torch.load(
                path, map_location=self.device)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        self.netg.load_state_dict(pretrained_dict)
        print(' Loaded weights.')

    @staticmethod
    def _min_max_scale(scores):
        """Scale ``scores`` linearly into [0, 1].

        Returns ``scores`` unchanged when empty and an all-zero tensor when
        every score is equal (plain division produced NaNs there).
        """
        if scores.numel() == 0:
            return scores
        low = torch.min(scores)
        span = torch.max(scores) - low
        if span.item() == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Scores each test sample by mean((latent_i - latent_o) ** 2), min-max
        scales the scores and evaluates them against ``gt``.

        Returns:
            OrderedDict with 'Avg Run Time (ms/batch)' and 'AUC'.

        Raises:
            IOError: Model weights not found (when ``opt.load_weights``).
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                self._load_pretrained_netg()

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            size = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(size, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(size, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(size, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(size, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                count = error.size(0)
                start = i * self.opt.batchsize
                end = start + count
                self.an_scores[start:end] = error.reshape(count)
                self.gt_labels[start:end] = self.gt.reshape(count)
                self.latent_i[start:end, :] = latent_i.reshape(
                    count, self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(
                    count, self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (mean over at most the first 100 batches).
            self.times = np.mean(np.array(self.times)[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
class SSnovelty(object): @staticmethod def name(): """Return name of the class. """ return 'SSnovelty' def __init__(self, opt): super(SSnovelty, self).__init__() ## # Initalize variables. self.opt = opt self.visualizer = Visualizer(opt) # self.warmup = hyperparameters['model_specifics']['warmup'] self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train') self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test') self.device = torch.device( "cuda:0" if self.opt.device != 'cpu' else "cpu") # -- Discriminator attributes. self.out_d_real = None self.feat_real = None self.err_d_real = None self.fake = None self.latent_i = None # self.latent_o = None self.out_d_fake = None self.feat_fake = None self.err_d_fake = None self.err_d = None self.idx = 0 self.opt.display = True # -- Generator attributes. self.out_g = None self.err_g_bce = None self.err_g_l1l = None self.err_g_enc = None self.err_g = None # -- Misc attributes self.epoch = 0 self.epoch1 = 0 self.times = [] self.total_steps = 0 ## # Create and initialize networks. self.netg = NetG(self.opt).to(self.device) self.netd = NetD(self.opt).to(self.device) self.netg.apply(weights_init) self.netd.apply(weights_init) self.netc = Class(self.opt).to(self.device) ## if self.opt.resume != '': print("\nLoading pre-trained networks.") # self.opt.iter = torch.load(os.path.join(self.opt.resume, 'netG.pth'))['epoch'] # self.netg.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netG.pth'))['state_dict']) # self.netd.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netD.pth'))['state_dict']) self.netc.load_state_dict( torch.load(os.path.join(self.opt.resume, 'class.pth'))['state_dict']) print("\tDone.\n") # print(self.netg) # print(self.netd) ## # Loss Functions self.bce_criterion = nn.BCELoss() self.l1l_criterion = nn.L1Loss() self.mse_criterion = nn.MSELoss() self.l2l_criterion = l2_loss self.loss_func = torch.nn.CrossEntropyLoss() ## # Initialize input tensors. 
self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.input_1 = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.label = torch.empty(size=(self.opt.batchsize, ), dtype=torch.float32, device=self.device) self.label_r = torch.empty(size=(self.opt.batchsize, ), dtype=torch.float32, device=self.device) self.gt = torch.empty(size=(self.opt.batchsize, ), dtype=torch.long, device=self.device) self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.real_label = 1 self.fake_label = 0 base = 1.0 sigma_list = [1, 2, 4, 8, 16] self.sigma_list = [sigma / base for sigma in sigma_list] ## # Setup optimizer if self.opt.isTrain: self.netg.train() self.netd.train() self.netc.train() self.optimizer_d = optim.Adam(self.netd.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_g = optim.Adam(self.netg.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_c = optim.Adam(self.netc.parameters(), lr=self.opt.lr_c, betas=(self.opt.beta1, 0.999)) def set_input(self, input): """ Set input and ground truth Args: input (FloatTensor): Input data for batch i. """ self.input.data.resize_(input[0].size()).copy_(input[0]) self.gt.data.resize_(input[1].size()).copy_(input[1]) # Copy the first batch as the fixed input. if self.total_steps == self.opt.batchsize: self.fixed_input.data.resize_(input[0].size()).copy_(input[0]) ## def update_netd(self): """ Update D network: Ladv = |f(real) - f(fake)|_2 """ ## # Feature Matching. 
self.netd.zero_grad() # -- # Train with real self.label.data.resize_(self.input.size(0)).fill_(self.real_label) self.out_d_real, self.feat_real = self.netd(self.input) # self.err_d_real = self.bce_criterion(self.out_d_real,self.label) # Train with fake self.label.data.resize_(self.input.size(0)).fill_(self.fake_label) self.fake, self.latent_i, = self.netg(self.input) self.out_d_fake, self.feat_fake = self.netd(self.fake.detach()) # self.err_d_fake = self.bce_criterion(self.out_d_fake, self.label) # -- # self.err_d = self.err_d_real + self.err_d_fake self.err_d = l2_loss(self.feat_real, self.feat_fake) # self.err_d = self.err_d_fake + self.err_d_l2 # self.err_d_real = self.err_d # self.err_d_fake = self.err_d self.err_d.backward() self.optimizer_d.step() ## def reinitialize_netd(self): """ Initialize the weights of netD """ self.netd.apply(weights_init) print('Reloading d net') def updata_netc(self): self.netc.zero_grad() output_real = self.netc(self.img_real) self.fake = self.netg(self.img_real) output_fake = self.netc(self.fake) self.err_c_fake = self.loss_func(output_fake, self.label_real) self.err_c_real = self.loss_func(output_real, self.label_real) self.err_c_dis = self.l2l_criterion(output_real, output_fake) self.err_c = self.err_c_fake + self.err_c_real + self.err_c_dis self.err_c.backward() self.optimizer_c.step() ## def trans_img(self, input): size = len(input) trans_map = torch.empty(size=(size, self.opt.nc, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) idx = int(size / 4) for i in range(idx): img = rotate_img_trans(input[i * 4 + 2], 1) trans_map[i * 4] = img img = rotate_img_trans(input[i * 4 + 3], 1) trans_map[i * 4 + 1] = img img = rotate_img_trans(input[i * 4], 1) trans_map[i * 4 + 2] = img img = rotate_img_trans(input[i * 4 + 1], 1) trans_map[i * 4 + 3] = img return trans_map def update_netg(self): """ # ============================================================ # # (2) Update G network: log(D(G(x))) + ||G(x) - x|| # # 
============================================================ # """ self.netg.zero_grad() # self.out_g, _ = self.netd(self.fake) # self.label.data.resize_(self.out_g.shape).fill_(self.real_label) # self.err_g_bce = self.bce_criterion(self.out_g, self.label) self.fake = self.netg(self.img_real) self.img_trans = self.trans_img(self.fake.detach().cpu()) self.err_g_r = self.mse_criterion(self.fake, self.img_trans) # self.err_g_l1l = self.mse_criterion(self.fake, self.img_real) # constrain x' to look like x # self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i) output_fake = self.netc(self.fake) output_real = self.netc(self.img_real) self.err_g_loss = self.l2l_criterion(output_fake, output_real) self.loss = self.loss_func(output_fake, self.label_real) # self.err_g = self.err_g_bce + self.err_g_l1l * self.opt.w_rec + (self.loss + self.err_g_loss) * self.opt.w_enc # self.err_g = self.err_g_bce + (loss + self.err_g_loss) * self.opt.w_enc # self.err_g = (self.loss + self.err_g_loss ) * self.opt.w_enc self.err_g = (self.err_g_r) * self.opt.w_rec + ( self.loss + self.err_g_loss) * self.opt.w_enc self.err_g.backward(retain_graph=True) self.optimizer_g.step() ## def argument_image_rotation_plus(self, X): size = len(X) self.img_real = torch.empty(size=(size * 4, self.opt.nc, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.label_real = torch.empty(size=(size * 4, ), dtype=torch.long, device=self.device) for idx in range(size): img0 = X[idx] for i in range(4): [img, label] = rotate_img(img0, i) self.img_real[idx * 4 + i] = img self.label_real[idx * 4 + i] = label def optimize(self): """ Optimize netD and netG networks. 
""" self.argument_image_rotation_plus(self.input) self.input_img = self.input.to(self.device) self.updata_netc() # self.update_netd() self.update_netg() # If D loss is zero, then re-initialize netD # if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5: # self.reinitialize_netd() ## def get_errors(self): """ Get netD and netG errors. Returns: [OrderedDict]: Dictionary containing errors. """ errors = OrderedDict([ ('err_d', self.err_d.item()), ('err_g', self.err_g.item()), ]) return errors def get_errors_1(self): """ Get netD and netG errors. Returns: [OrderedDict]: Dictionary containing errors. """ errors = OrderedDict([ ('err_c_real', self.err_c_real.item()), ('err_c_fake', self.err_c_fake.item()), ('err_c', self.err_c.item()), ]) return errors ## def get_current_images(self): """ Returns current images. Returns: [reals, fakes, fixed] """ reals = self.img_real.data fakes = self.fake.data trans = self.img_trans.data # fixed = self.netg(self.fixed_input)[0].data # fixed_input = self.fixed_input.data return reals, fakes, trans ## def save_weights(self, epoch): """Save netG and netD weights for the current epoch. Args: epoch ([int]): Current epoch number. 
""" weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights') if not os.path.exists(weight_dir): os.makedirs(weight_dir) torch.save({ 'epoch': epoch + 1, 'state_dict': self.netg.state_dict() }, '%s/netG.pth' % (weight_dir)) torch.save({ 'epoch': epoch + 1, 'state_dict': self.netd.state_dict() }, '%s/netD.pth' % (weight_dir)) ## def train_step(self): self.netg.train() epoch_iter = 0 for step, (x, y, z) in enumerate(self.train_loader): self.total_steps += self.opt.batchsize epoch_iter += self.opt.batchsize self.input = Variable(x) self.optimize() if self.total_steps % self.opt.save_image_freq == 0: reals, fakes, trans = self.get_current_images() self.visualizer.save_current_images(self.epoch, reals, fakes, trans) if self.opt.display: self.visualizer.display_current_images(reals, fakes, trans) # errors = self.get_errors() # if self.total_steps % self.opt.save_image_freq == 0: # reals, fakes, fixed , fixed_input = self.get_current_images() # self.visualizer.save_current_images(self.epoch, reals, fakes, fixed ) # if self.opt.display: # self.visualizer.display_current_images(reals, fakes, fixed,fixed_input) # print('Epoch %d err_g %f err_d %f err_c %f' % (self.epoch, self.err_g.item(),self.err_d.item(),self.err_c.item())) print('Epoch %d err_g %f err_c %f' % (self.epoch, self.err_g.item(), self.err_c.item())) # print('Epoch %d err_d_real %f err_d_fake %f ' % (self.epoch, self.err_d_real.item(), self.err_d_fake.item())) # print('Epoch %d err_g_bce %f err_g_loss %f err_g_l1 %f loss %f ' % (self.epoch, self.err_g_bce.item(), # self.err_g_loss.item(),self.err_g_l1l.item(),self.loss.item())) # print(">> Training model %s. Epoch %d/%d" % (self.name(), self.epoch + 1, self.opt.niter)) def train(self): """ Train the model """ ## # TRAIN self.total_steps = 0 best_auc = [0, 0, 0] # Train for niter epochs. print(">> Training model %s." 
% self.name()) train_data, test_data = cifa10Data(self.opt.normalclass) self.train_loader = DataLoader(train_data, batch_size=self.opt.batchsize, shuffle=True, num_workers=0, pin_memory=True) self.test_loader = DataLoader(test_data, batch_size=self.opt.batchsize, shuffle=False, num_workers=0, pin_memory=True) for self.epoch in range(self.opt.iter, self.opt.niter): # Train for one epoch self.train_step() if self.epoch % 20 == 0: rec = self.test() if rec['AUC_R'] > best_auc[0]: best_auc[0] = rec['AUC_R'] if rec['AUC_C_real'] > best_auc[1]: best_auc[1] = rec['AUC_C_real'] if rec['AUC_C_fake'] > best_auc[2]: best_auc[2] = rec['AUC_C_fake'] # self.visualizer.print_current_performance(rec, best_auc) f = open('./output/testclass.txt', 'a', encoding='utf-8-sig', newline='\n') # f.write('rec: ' + str(rec['AUC_R'], ) + ', ' + str(rec['AUC_C_real'], ) + ',' + str(rec['AUC_C_fake'], ) + '\n') f.write('best' + str(best_auc) + '\n') f.close() # self.test_1() # self.visualizer.print_current_performance(res, best_auc) print(">> Training model %s.[Done]" % self.name()) self.test() # self.test_1() ## def test(self): with torch.no_grad(): self.opt.load_weights = True self.epoch1 = 1 self.epoch2 = 200 self.total_steps = 0 epoch_iter = 0 print('test') label = torch.zeros(size=(10000, ), dtype=torch.long, device=self.device) pre = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) pre_real = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.relation = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.distance = torch.zeros(size=(40000, ), dtype=torch.float32, device=self.device) self.relation_img = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.distance_img = torch.zeros(size=(40000, ), dtype=torch.float32, device=self.device) self.opt.phase = 'test' for i, (x, y, z) in enumerate(self.test_loader): self.input = Variable(x) self.label_r = Variable(z) self.total_steps += self.opt.batchsize 
self.argument_image_rotation_plus(self.input) self.label_r = self.label_r.to(self.device) classfiear_real_1 = self.netc(self.img_real) classfiear_real = F.softmax(classfiear_real_1, dim=1) prediction_real = -(torch.log(classfiear_real)) self.fake = self.netg(self.img_real) classfiear_1 = self.netc(self.fake) classfiear = F.softmax(classfiear_1, dim=1) prediction = -(torch.log(classfiear)) aaaa = (prediction.size(0) / 4) aaaa = int(aaaa) # prediction = prediction * (-1/4) label_z = torch.zeros(size=(aaaa, ), dtype=torch.long, device=self.device) pre_score = torch.zeros(size=(aaaa, ), dtype=prediction.dtype, device=self.device) pre_score_real = torch.zeros(size=(aaaa, ), dtype=prediction.dtype, device=self.device) self.img_trans = self.trans_img(self.fake.cpu()) distance_img = torch.mean( torch.pow((self.fake - self.img_real), 2), -1) distance_img = torch.mean(torch.mean(distance_img, -1), -1) if self.total_steps % self.opt.save_image_freq == 0: reals, fakes, trans = self.get_current_images() self.visualizer.save_test_images(i, reals, fakes, trans) if self.opt.display: self.visualizer.display_test_images( reals, fakes, trans) # distance = torch.mean(torch.pow((classfiear_1 - classfiear_real_1), 2), -1) # self.distance[i * self.opt.batchsize: i * self.opt.batchsize + distance.size(0)] = distance.reshape( # distance.size(0)) self.distance_img[i * 64:i * 64 + distance_img.size(0)] = distance_img.reshape( distance_img.size(0)) for k in range(aaaa): # label_z[k] = self.label_r[k * 4] pre_score[k] = (prediction[k * 4, 0] + prediction[k * 4 + 1, 1] + prediction[k * 4 + 2, 2] + prediction[k * 4 + 3, 3]) / 4 pre_score_real[k] = (prediction_real[k * 4, 0] + prediction_real[k * 4 + 1, 1] + prediction_real[k * 4 + 2, 2] + prediction_real[k * 4 + 3, 3]) / 4 label[i * self.opt.batchsize:i * self.opt.batchsize + aaaa] = self.label_r pre[i * self.opt.batchsize:i * self.opt.batchsize + aaaa] = pre_score pre_real[i * self.opt.batchsize:i * self.opt.batchsize + aaaa] = pre_score_real 
for j in range(10000): # self.relation[j] = self.distance[j*4] self.relation_img[j] = self.distance_img[j * 4] # D = pre + self.relation * 0.2 # D_real = pre_real + self.relation * 0.2 # # mu = torch.mul(pre, self.relation) # mu_real = torch.mul(pre_real, self.relation) aaaa = self.relation_img.cpu().numpy() np.savetxt('./output/log.txt', aaaa) bbbb = label.cpu().numpy() np.savetxt('./output/label.txt', bbbb) # auc_mu_fake = evaluate(label, mu, metric=self.opt.metric) # auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric) # auc_d_fake = evaluate(label, D, metric=self.opt.metric) # auc_d_real = evaluate(label, D_real, metric=self.opt.metric) auc_c_fake = evaluate(label, pre, metric=self.opt.metric) auc_c_real = evaluate(label, pre_real, metric=self.opt.metric) # auc_r = evaluate(label, self.relation, metric=self.opt.metric) auc_r_img = evaluate(label, self.relation_img, metric=self.opt.metric) performance = OrderedDict([('AUC_R', auc_r_img), ('AUC_C_real', auc_c_real), ('AUC_C_fake', auc_c_fake)]) print('test done') return performance # print('Train mul_real ROC AUC Score: %f mu_fake: %f' % (auc_mu_real, auc_mu_fake)) # print('Train add_real ROC AUC Score: %f add_fake: %f' % (auc_d_real, auc_d_fake)) def test_1(self): with torch.no_grad(): self.total_steps_test = 0 epoch_iter = 0 print('test') label = torch.zeros(size=(10000, ), dtype=torch.long, device=self.device) pre = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) pre_real = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.relation = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.relation_img = torch.zeros(size=(10000, ), dtype=torch.float32, device=self.device) self.classifiear = torch.zeros(size=(10000, 4), dtype=torch.float32, device=self.device) self.opt.phase = 'test' for i, (x, y, z) in enumerate(self.test_loader): self.input = Variable(x) self.label_rrr = Variable(z) self.input = self.input.to(self.device) self.label_rrr = 
self.label_rrr.to(self.device) size = int(self.input.size(0) / 4) input_1 = torch.empty(size=(size, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) input_2 = torch.empty(size=(size, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) input_3 = torch.empty(size=(size, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) input_4 = torch.empty(size=(size, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) classfiear_real_1 = self.netc(self.input) classfiear_real = F.softmax(classfiear_real_1, dim=1) prediction_real = -(torch.log(classfiear_real)) for j in range(size): input_1[j] = self.input[j * 4] input_2[j] = self.input[j * 4 + 1] input_3[j] = self.input[j * 4 + 2] input_4[j] = self.input[j * 4 + 3] output_1 = self.netg(input_1) output_2 = self.netg(input_2) output_3 = self.netg(input_3) output_4 = self.netg(input_4) classifiear_real = self.netc(input_1) classfiear_11 = self.netc(output_1) classfiear_21 = self.netc(output_2) classfiear_31 = self.netc(output_3) classfiear_41 = self.netc(output_4) classfiear_1 = F.softmax(classfiear_11, dim=1) classfiear_2 = F.softmax(classfiear_21, dim=1) classfiear_3 = F.softmax(classfiear_31, dim=1) classfiear_4 = F.softmax(classfiear_41, dim=1) prediction_1 = -(torch.log(classfiear_1)) prediction_2 = -(torch.log(classfiear_2)) prediction_3 = -(torch.log(classfiear_3)) prediction_4 = -(torch.log(classfiear_4)) aaaa = prediction_1.size(0) self.classifiear[i * 16:i * 16 + aaaa] = classfiear_11 # prediction = prediction * (-1/4) label_z = torch.zeros(size=(aaaa, ), dtype=torch.long, device=self.device) pre_score = torch.zeros(size=(aaaa, ), dtype=prediction_1.dtype, device=self.device) pre_score_real = torch.zeros(size=(aaaa, ), dtype=prediction_1.dtype, device=self.device) distance_img = torch.mean(torch.pow((output_1 - input_1), 2), -1) distance_img = torch.mean(torch.mean(distance_img, -1), -1) distance = torch.mean( 
torch.pow((classifiear_real - classfiear_11), 2), -1) self.relation[i * 16:i * 16 + distance.size(0)] = distance.reshape( distance.size(0)) self.relation_img[i * 16:i * 16 + distance.size(0)] = distance_img.reshape( distance.size(0)) for k in range(aaaa): label_z[k] = self.label_rrr[k * 4] pre_score[k] = (prediction_1[k, 0] + prediction_2[k, 1] + prediction_3[k, 2] + prediction_4[k, 3]) / 4 pre_score_real[k] = (prediction_real[k * 4, 0] + prediction_real[k * 4 + 1, 1] + prediction_real[k * 4 + 2, 2] + prediction_real[k * 4 + 3, 3]) / 4 label[i * 16:i * 16 + aaaa] = label_z pre[i * 16:i * 16 + aaaa] = pre_score pre_real[i * 16:i * 16 + aaaa] = pre_score_real D = pre + self.relation * 0.2 D_real = pre_real + self.relation * 0.2 aaaa = self.classifiear.cpu().numpy() np.savetxt('./output/log.txt', aaaa) bbbb = label.cpu().numpy() np.savetxt('./output/label.txt', bbbb) mu = torch.mul(pre, self.relation) mu_real = torch.mul(pre_real, self.relation) auc_mu_fake = evaluate(label, mu, metric=self.opt.metric) auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric) auc_d_fake = evaluate(label, D, metric=self.opt.metric) auc_d_real = evaluate(label, D_real, metric=self.opt.metric) auc_c_fake = evaluate(label, pre, metric=self.opt.metric) auc_c_real = evaluate(label, pre_real, metric=self.opt.metric) auc_r = evaluate(label, self.relation, metric=self.opt.metric) auc_r_img = evaluate(label, self.relation_img, metric=self.opt.metric) print('Train mul_real ROC AUC Score: %f mu_fake: %f' % (auc_mu_real, auc_mu_fake)) print('Train add_real ROC AUC Score: %f add_fake: %f' % (auc_d_real, auc_d_fake)) print('Train class_real ROC AUC Score: %f class_fake: %f' % (auc_c_real, auc_c_fake)) print('Train recon ROC AUC Score: %f recon_img:%f' % (auc_r, auc_r_img)) print('test done')
class BaseModel():
    """ Base Model for ganomaly

    Shared seeding, checkpointing, training and evaluation loop. Subclasses
    must provide the networks (``netg``, ``netd``), the pre-allocated tensors
    (``input``, ``gt``, ``label``, ``fixed_input``), a ``name`` attribute,
    the loss tensors reported by ``get_errors`` and an ``optimize_params``
    method; none of them are created in this class.
    """

    def __init__(self, opt, dataloader):
        """ Store options and dataloaders; set up output dirs and device.

        Args:
            opt: Options namespace (``manualseed``, ``outf``, ``name`` and
                ``device`` are read here).
            dataloader: Mapping indexed with ``'train'`` and ``'test'``.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def set_input(self, input:torch.Tensor):
        """ Set input and ground truth

        Args:
            input: Batch indexed as ``input[0]`` (images, copied into
                ``self.input``) and ``input[1]`` (labels, copied into
                ``self.gt``).
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            # Only resized here; its contents are not written.
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def seed(self, seed_value):
        """ Seed every RNG used during training.

        Arguments:
            seed_value {int} -- Seed to apply; ``-1`` leaves all RNGs untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed] -- ``fixed`` is the first output of
            ``netg`` applied to ``fixed_input``.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Files go to ``<outf>/<name>/train/weights/netG.pth`` and ``netD.pth``
        and record ``epoch + 1`` as the stored epoch.

        Args:
            epoch ([int]): Current epoch number.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
                   '%s/netG.pth' % (weight_dir))
        torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
                   '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.

        Calls ``optimize_params`` on every training batch; periodically plots
        errors and saves/displays images according to ``opt.print_freq`` and
        ``opt.save_image_freq``.
        """
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs epochs ``opt.iter`` .. ``opt.niter - 1``. After each epoch the
        model is evaluated with ``test`` and the weights are saved whenever
        the AUC beats the best value seen so far.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    @staticmethod
    def _min_max_scale(scores):
        """Linearly rescale ``scores`` into [0, 1].

        Returns zeros when all scores are equal instead of the NaN vector a
        plain ``(x - min) / (max - min)`` produces; empty input is returned
        as a copy.
        """
        if scores.numel() == 0:
            return scores.clone()
        lo = torch.min(scores)
        span = torch.max(scores) - lo
        if span.item() == 0:
            return torch.zeros_like(scores)
        return (scores - lo) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Scores every test sample by the mean squared difference between the
        two latent codes returned by ``netg``, min-max scales the scores and
        evaluates them against the ground-truth labels.

        Returns:
            OrderedDict with ``'Avg Run Time (ms/batch)'`` and ``'AUC'``.

        Raises:
            IOError: ``opt.load_weights`` is set and netG weights cannot be read.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                try:
                    # torch.load is what fails on a missing file, so it must be
                    # inside the try for this handler to ever fire.
                    pretrained_dict = torch.load(path)['state_dict']
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print(' Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_test,), dtype=torch.long, device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            # NOTE(review): this resets the counter train_one_epoch increments,
            # so set_input copies the first *test* batch into fixed_input.
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
                time_o = time.time()

                count = error.size(0)
                start = i * self.opt.batchsize
                end = start + count
                self.an_scores[start:end] = error.reshape(count)
                self.gt_labels[start:end] = self.gt.reshape(count)
                self.latent_i[start:end, :] = latent_i.reshape(count, self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(count, self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    os.makedirs(dst, exist_ok=True)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)

            # Measure inference time (mean over at most the first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)
            return performance
class BaseModel():
    """ Base Model for ganomaly

    Shared seeding, checkpointing, training and evaluation loop. Subclasses
    must provide the networks (``netg``, ``netd``), the pre-allocated tensors
    (``input``, ``gt``, ``label``, ``fixed_input``, ``noise``), a ``name``
    attribute, the loss tensors reported by ``get_errors``,
    ``optimize_params`` and (for ``update_learning_rate``) ``schedulers`` /
    ``optimizers``; none of them are created in this class.
    """

    def __init__(self, opt, data):
        """ Store options and data loaders; set up output dirs and device.

        Args:
            opt: Options namespace (``manualseed``, ``outf``, ``name`` and
                ``device`` are read here).
            data: Object exposing ``train`` and ``valid`` loaders.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.data = data
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def seed(self, seed_value):
        """ Seed every RNG used during training.

        Arguments:
            seed_value {int} -- Seed to apply; ``-1`` leaves all RNGs untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def set_input(self, input:torch.Tensor, noise:bool=False):
        """ Set input and ground truth

        Args:
            input: Batch indexed as ``input[0]`` (images, copied into
                ``self.input``) and ``input[1]`` (labels, copied into
                ``self.gt``).
            noise: When True, refill ``self.noise`` with fresh
                standard-normal samples.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            # Only resized here; its contents are not written.
            self.label.resize_(input[1].size())

            # Resample the noise buffer (it is not added to self.input here).
            if noise:
                self.noise.data.copy_(torch.randn(self.noise.size()))

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_lat', self.err_g_lat.item())])

        return errors

    ##
    def reinit_d(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed] -- ``fixed`` is the first output of
            ``netg`` applied to ``fixed_input``.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch:int, is_best:bool=False):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number (stored as-is).
            is_best (bool): Write ``netG_best.pth`` / ``netD_best.pth``
                instead of ``netG_<epoch>.pth`` / ``netD_<epoch>.pth``.
        """
        weight_dir = os.path.join(
            self.opt.outf, self.opt.name, 'train', 'weights')
        os.makedirs(weight_dir, exist_ok=True)

        if is_best:
            torch.save({'epoch': epoch, 'state_dict': self.netg.state_dict()}, f'{weight_dir}/netG_best.pth')
            torch.save({'epoch': epoch, 'state_dict': self.netd.state_dict()}, f'{weight_dir}/netD_best.pth')
        else:
            torch.save({'epoch': epoch, 'state_dict': self.netd.state_dict()}, f"{weight_dir}/netD_{epoch}.pth")
            torch.save({'epoch': epoch, 'state_dict': self.netg.state_dict()}, f"{weight_dir}/netG_{epoch}.pth")

    def load_weights(self, epoch=None, is_best:bool=False, path=None):
        """ Load pre-trained weights of NetG and NetD

        Keyword Arguments:
            epoch {int} -- Epoch to be loaded (default: {None})
            is_best {bool} -- Load the best epoch (default: {False})
            path {str} -- Directory holding the netG/netD weight files
                (default: {None}, meaning
                ``./output/<name>/<dataset>/train/weights``)

        Raises:
            ValueError -- Neither ``epoch`` nor ``is_best`` was given.
            IOError -- A weight file could not be read.
        """
        if epoch is None and is_best is False:
            raise ValueError('Please provide epoch to be loaded or choose the best epoch.')

        if is_best:
            fname_g = "netG_best.pth"
            fname_d = "netD_best.pth"
        else:
            fname_g = f"netG_{epoch}.pth"
            fname_d = f"netD_{epoch}.pth"

        # Previously path_g/path_d were only bound when path was None, so an
        # explicit path raised NameError.
        if path is None:
            path = f"./output/{self.name}/{self.opt.dataset}/train/weights"
        path_g = f"{path}/{fname_g}"
        path_d = f"{path}/{fname_d}"

        # Load the weights of netg and netd.
        print('>> Loading weights...')
        try:
            weights_g = torch.load(path_g)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        try:
            weights_d = torch.load(path_d)['state_dict']
        except IOError as err:
            raise IOError("netD weights not found") from err

        self.netg.load_state_dict(weights_g)
        self.netd.load_state_dict(weights_d)
        print(' Done.')

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.

        Calls ``optimize_params`` on every batch of ``self.data.train``;
        periodically plots errors and saves/displays images according to
        ``opt.print_freq`` and ``opt.save_image_freq``.
        """
        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.data.train, leave=False, total=len(self.data.train)):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.data.train.dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs epochs ``opt.iter`` .. ``opt.niter - 1``. After each epoch the
        model is evaluated with ``test`` and per-epoch weights are saved
        whenever the AUC beats the best value seen so far.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(f">> Training {self.name} on {self.opt.dataset} to detect {self.opt.abnormal_class}")
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    @staticmethod
    def _min_max_scale(scores):
        """Linearly rescale ``scores`` into [0, 1].

        Returns zeros when all scores are equal instead of the NaN vector a
        plain ``(x - min) / (max - min)`` produces; empty input is returned
        as a copy.
        """
        if scores.numel() == 0:
            return scores.clone()
        lo = torch.min(scores)
        span = torch.max(scores) - lo
        if span.item() == 0:
            return torch.zeros_like(scores)
        return (scores - lo) / span

    ##
    def test(self):
        """ Test GANomaly model on ``self.data.valid``.

        Scores every sample by the mean squared difference between the two
        latent codes returned by ``netg``, min-max scales the scores and
        evaluates them against the ground-truth labels.

        Returns:
            OrderedDict with ``'Avg Run Time (ms/batch)'`` and ``'AUC'``.

        Raises:
            IOError: ``opt.load_weights`` is set and netG weights cannot be read.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                # NOTE(review): save_weights in this class writes netG_<epoch>.pth /
                # netG_best.pth, never netG.pth -- confirm which file is intended.
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                try:
                    # torch.load is what fails on a missing file, so it must be
                    # inside the try for this handler to ever fire.
                    pretrained_dict = torch.load(path)['state_dict']
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print(' Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            n_valid = len(self.data.valid.dataset)
            self.an_scores = torch.zeros(size=(n_valid,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_valid,), dtype=torch.long, device=self.device)
            self.latent_i = torch.zeros(size=(n_valid, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o = torch.zeros(size=(n_valid, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            # NOTE(review): this resets the counter train_one_epoch increments,
            # so set_input copies the first *validation* batch into fixed_input.
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.data.valid, 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
                time_o = time.time()

                count = error.size(0)
                start = i * self.opt.batchsize
                end = start + count
                self.an_scores[start:end] = error.reshape(count)
                self.gt_labels[start:end] = self.gt.reshape(count)
                self.latent_i[start:end, :] = latent_i.reshape(count, self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(count, self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    os.makedirs(dst, exist_ok=True)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)

            # Measure inference time (mean over at most the first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.data.valid.dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)
            return performance

    ##
    def update_learning_rate(self):
        """ Update learning rate based on the rule provided in options.

        Steps every scheduler in ``self.schedulers`` and prints the learning
        rate of the first parameter group of ``self.optimizers[0]``.
        """
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print(' LR = %.7f' % lr)
class ocgan(object): """GANomaly Class """ @staticmethod def name(): """Return name of the class. """ return 'ocgan' def __init__(self, opt, dataloader=None): super(ocgan, self).__init__() ## # Initalize variables. self.opt = opt self.visualizer = Visualizer(opt) self.dataloader = dataloader self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train') self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test') # -- Discriminator attributes. self.out_dv0=None self.out_dv1=None self.out_dv=None self.label_dv=None self.err_dv_bce=None self.out_dl0=None self.out_dl1=None self.out_dl=None self.label_dl=None self.err_dl_bce=None self.err_d=None # -- Generator attributes. self.out_gv0 = None self.out_gv1 = None self.out_gv = None self.label_gv = None self.err_gv_bce = None self.out_gl0 = None self.out_gl1 = None self.out_gl = None self.label_gl = None self.err_gl_bce = None self.err_g_mse =None self.err_g = None # -- Classfier attribute self.out_c0=None self.out_c1=None self.out_c=None self.label_c=None self.err_c_bce = None # -- mine attribute self.out_m = None self.err_m_bce = None # -- Misc attributes self.epoch = 0 self.times = [] self.total_steps = 0 ## # Create and initialize networks. 
print('Create and initialize networks.') self.neten=Encoder(self.opt.isize,self.opt.nc,self.opt.ndf) self.netde=Decoder(self.opt.isize,self.opt.nc,self.opt.ngf) self.netdl=Dl(self.opt.nz) self.netdv=Dv(self.opt) self.netc=Classfier(self.opt) self.netdv.apply(weights_init) self.netc.apply(weights_init) self.neten.apply(weights_init) self.netde.apply(weights_init) print('end') ## if self.opt.resume != '': print("\nLoading pre-trained networks.") self.opt.iter = torch.load(os.path.join(self.opt.resume, 'neten.pth'))['epoch'] self.neten.load_state_dict(torch.load(os.path.join(self.opt.resume, 'neten.pth'))['state_dict']) self.netde.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netde.pth'))['state_dict']) self.netdl.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netdl.pth'))['state_dict']) self.netdv.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netdv.pth'))['state_dict']) self.netc.load_state_dict(torch.load(os.path.join(self.opt.resume, 'netC.pth'))['state_dict']) print("\tDone.\n") print(self.neten) print(self.netde) print(self.netdl) print(self.netdv) print(self.netc) ## # Loss Functions self.bce_criterion = nn.BCELoss() self.l1l_criterion = nn.L1Loss() self.l2l_criterion = nn.MSELoss() ## # Initialize input tensors. 
print('Initialize input tensors.') self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32) self.label = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32) self.labelf = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32) self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long) self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32) self.real_label = 1 self.fake_label = 0 self.u=None self.n=None self.l1=None self.l2=torch.empty(size=(self.opt.batchsize, self.opt.nz,1,1), dtype=torch.float32) self.del1=None self.del2=None print('end') ## # Setup optimizer print('Setup optimizer') if self.opt.isTrain: self.neten.train() self.netde.train() self.netdv.train() self.netdl.train() self.netc.train() self.optimizer_en = optim.Adam(self.neten.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_de = optim.Adam(self.netde.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_dl = optim.Adam(self.netdl.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_dv = optim.Adam(self.netdv.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_c = optim.Adam(self.netc.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_l2 = optim.Adam([{'params':self.l2}], lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) print('end') ## def set_input(self, input): """ Set input and ground truth Args: input (FloatTensor): Input data for batch i. """ self.input.data.resize_(input[0].size()).copy_(input[0]) self.gt.data.resize_(input[1].size()).copy_(input[1]) # Copy the first batch as the fixed input. if self.total_steps == self.opt.batchsize: self.fixed_input.data.resize_(input[0].size()).copy_(input[0]) ## def update_netd(self): """ Update D network: Ladv = |f(real) - f(fake)|_2 """ ## # Feature Matching. 
self.netdv.zero_grad() self.netdl.zero_grad() # -- self.out_dv0 = self.netdv(self.del2.detach()) self.out_dv1 = self.netdv(self.input) self.out_dv = torch.cat([self.out_dv0, self.out_dv1], 0) self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) self.label.data.resize_(self.opt.batchsize).fill_(self.real_label) self.label_dv = torch.cat([self.labelf, self.label], 0) self.err_dv_bce = self.bce_criterion(self.out_dv, self.label_dv) self.out_dl0 = self.netdl(self.l1.detach()) self.out_dl1 = self.netdl(self.l2) self.out_dl = torch.cat([self.out_dl0, self.out_dl1], 0) self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) self.label.data.resize_(self.opt.batchsize).fill_(self.real_label) self.label_dl = torch.cat([self.labelf, self.label], 0) self.err_dl_bce = self.bce_criterion(self.out_dl, self.label_dl) self.err_d=self.err_dl_bce+self.err_dv_bce self.err_d.backward(retain_graph=True) self.optimizer_dv.step() self.optimizer_dl.step() ## def update_netg(self): """ # ============================================================ # # (2) Update G network: log(D(G(x))) + ||G(x) - x|| # # ============================================================ # """ self.neten.zero_grad() self.netde.zero_grad() # -- # self.out_gv0 = self.netdv(self.input) # self.out_gv1 = self.netdv(self.del2) # self.out_gv = torch.cat([self.out_gv0, self.out_gv1], 0) self.out_gv1 = self.netdv(self.del2) # self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) self.label.data.resize_(self.opt.batchsize).fill_(self.real_label) # self.label_gv = torch.cat([self.labelf, self.label], 0) # self.err_gv_bce = self.bce_criterion(self.out_gv, self.label_gv) self.err_gv_bce = self.bce_criterion(self.out_gv1, self.label) # self.out_gl0 = self.netdl(self.l2) # self.out_gl1 = self.netdl(self.l1) # self.out_gl = torch.cat([self.out_gl0, self.out_gl1], 0) self.out_gl1 = self.netdl(self.l1) # self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) 
self.label.data.resize_(self.opt.batchsize).fill_(self.real_label) # self.label_gl = torch.cat([self.labelf, self.label], 0) # self.err_gl_bce = self.bce_criterion(self.out_gl, self.label_gl) self.err_gl_bce = self.bce_criterion(self.out_gl1, self.label) self.err_g_mse=self.l2l_criterion(self.input, self.del1) self.err_g = self.err_gl_bce + self.err_gv_bce+self.err_g_mse self.err_g.backward(retain_graph=True) self.optimizer_en.step() self.optimizer_de.step() ## def update_netc(self): self.netc.zero_grad() self.out_c0=self.netc(self.del2.detach()) self.out_c1 = self.netc(self.input) # self.out_c1=self.netc(self.del1.detach()) self.out_c=torch.cat([self.out_c0,self.out_c1],0) self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) self.label.data.resize_(self.opt.batchsize).fill_(self.real_label) self.label_c=torch.cat([self.labelf,self.label],0) self.err_c_bce = self.bce_criterion(self.out_c, self.label_c) self.err_c_bce.backward() self.optimizer_c.step() def update_l2(self): for i in range(5): self.optimizer_l2.zero_grad() self.out_m=self.netc(self.del2) self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label) self.err_m_bce = self.bce_criterion(self.out_m, self.labelf) self.err_m_bce.backward() self.optimizer_l2.step() self.del2=self.netde(self.l2) def optimize(self): """ Optimize netD and netG networks. """ self.u = np.random.uniform(-1, 1, (self.opt.batchsize, self.opt.nz, 1, 1)) self.l2 = torch.from_numpy(self.u).float() self.n = torch.randn(self.opt.batchsize, self.opt.nc, self.opt.isize, self.opt.isize) self.l1 = self.neten(self.input + self.n) self.del1=self.netde(self.l1) self.del2=self.netde(self.l2) self.update_netc() self.update_netd() if self.opt.mine==True: self.update_l2() self.update_netg() ## def get_errors(self): """ Get netD and netG errors. Returns: [OrderedDict]: Dictionary containing errors. 
""" errors = OrderedDict([('err_d', self.err_d.item()), ('err_g', self.err_g.item()), ('err_dv_bce', self.err_dv_bce.item()), ('err_dl_bce', self.err_dl_bce.item()), ('err_gv_bce', self.err_gv_bce.item()), ('err_gl_l1l', self.err_gl_bce.item()), ('err_g_mse', self.err_g_mse.item()), ('err_c_bce', self.err_c_bce.item())]) return errors ## def get_current_images(self): """ Returns current images. Returns: [reals, fakes, fixed] """ reals = self.input.data fakes = self.del1.data fixed = self.netde(self.neten(self.fixed_input))[0].data return reals, fakes, fixed ## def save_weights(self, epoch): """Save netG and netD weights for the current epoch. Args: epoch ([int]): Current epoch number. """ weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights') if not os.path.exists(weight_dir): os.makedirs(weight_dir) torch.save({'epoch': epoch + 1, 'state_dict': self.neten.state_dict()}, '%s/neten.pth' % (weight_dir)) torch.save({'epoch': epoch + 1, 'state_dict': self.netde.state_dict()}, '%s/netde.pth' % (weight_dir)) torch.save({'epoch': epoch + 1, 'state_dict': self.netdl.state_dict()}, '%s/netdl.pth' % (weight_dir)) torch.save({'epoch': epoch + 1, 'state_dict': self.netdv.state_dict()}, '%s/netdv.pth' % (weight_dir)) torch.save({'epoch': epoch + 1, 'state_dict': self.netc.state_dict()}, '%s/netc.pth' % (weight_dir)) ## def train_epoch(self): """ Train the model for one epoch. 
""" self.neten.train() self.netde.train() epoch_iter = 0 for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])): self.total_steps += self.opt.batchsize epoch_iter += self.opt.batchsize self.set_input(data) self.optimize() if self.total_steps % self.opt.print_freq == 0: errors = self.get_errors() if self.opt.display: counter_ratio = float(epoch_iter) / len(self.dataloader['train'].dataset) self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors) if self.total_steps % self.opt.save_image_freq == 0: reals, fakes, fixed = self.get_current_images() self.visualizer.save_current_images(self.epoch, reals, fakes, fixed) if self.opt.display: self.visualizer.display_current_images(reals, fakes, fixed) print(">> Training model %s. Epoch %d/%d" % (self.name(), self.epoch+1, self.opt.niter)) # self.visualizer.print_current_errors(self.epoch, errors) ## def train(self): """ Train the model """ ## # TRAIN self.total_steps = 0 best_auc = 0 # Train for niter epochs. print(">> Training model %s." % self.name()) for self.epoch in range(self.opt.iter, self.opt.niter): # Train for one epoch self.train_epoch() res = self.test() # if res['AUC'] > best_auc: # best_auc = res['AUC'] # self.save_weights(self.epoch) if res['AUC'] > best_auc: best_auc = res['AUC'] self.save_weights(self.epoch) self.visualizer.print_current_performance(res, best_auc) print(">> Training model %s.[Done]" % self.name()) ## def test(self): """ Test GANomaly model. Args: dataloader ([type]): Dataloader for the test set Raises: IOError: Model weights not found. """ with torch.no_grad(): # Load the weights of netg and netd. 
if self.opt.load_weights: path = "./output/{}/{}/train/weights/netc.pth".format(self.name().lower(), self.opt.dataset) pretrained_dict = torch.load(path)['state_dict'] try: self.netc.load_state_dict(pretrained_dict) except IOError: raise IOError("netc weights not found") print(' Loaded weights.') self.opt.phase = 'test' # Create big error tensor for the test set. self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long) self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32) # print(" Testing model %s." % self.name()) self.times = [] self.total_steps = 0 epoch_iter = 0 # print(self.dataloader['test']) # print(type(self.dataloader['test'])) for i, data in enumerate(self.dataloader['test'], 0): # # print(data) self.total_steps += self.opt.batchsize epoch_iter += self.opt.batchsize time_i = time.time() self.set_input(data) self.out_c = self.netc(self.input) time_o = time.time() self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + self.out_c.size(0)] = self.out_c self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+self.out_c.size(0)] = self.gt.reshape(self.out_c.size(0)) self.times.append(time_o - time_i) # Save test images. if self.opt.save_test_images: dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images') if not os.path.isdir(dst): os.makedirs(dst) real, fake, _ = self.get_current_images() vutils.save_image(real, '%s/real_%03d.png' % (dst, i+1), normalize=True,nrow=4) vutils.save_image(fake, '%s/fake_%03d.png' % (dst, i+1), normalize=True,nrow=4) # Measure inference time. 
self.times = np.array(self.times) self.times = np.mean(self.times[:100] * 1000) # auc, eer = roc(self.gt_labels, self.an_scores) auc = evaluate(self.gt_labels, self.an_scores,metric=self.opt.metric) performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)]) if self.opt.display_id > 0 and self.opt.phase == 'test': counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset) self.visualizer.plot_performance(self.epoch, counter_ratio, performance) return performance def test1(self): """ Test GANomaly model. """ with torch.no_grad(): # Load the weights of netg and netd. if self.opt.load_weights: path = "./output/{}/{}/train/weights/neten.pth".format(self.name().lower(), self.opt.dataset) pretrained_dict = torch.load(path)['state_dict'] try: self.neten.load_state_dict(pretrained_dict) except IOError: raise IOError("netc weights not found") print(' Loaded weights.') self.opt.phase = 'test' # Create big error tensor for the test set. self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long) self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),self.opt.nz), dtype=torch.float32) # print(" Testing model %s." % self.name()) self.times = [] self.total_steps = 0 epoch_iter = 0 # print(self.dataloader['test']) # print(type(self.dataloader['test'])) for i, data in enumerate(self.dataloader['test'], 0): # # print(data) self.total_steps += self.opt.batchsize epoch_iter += self.opt.batchsize time_i = time.time() self.set_input(data) self.l1 = self.neten(self.input) time_o = time.time() self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + self.l1.size(0),:] = self.l1.reshape(self.l1.size(0),self.opt.nz) self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+self.l1.size(0)] = self.gt.reshape(self.l1.size(0)) self.times.append(time_o - time_i) x=self.an_scores.numpy() y=self.gt_labels.numpy() return x,y
class BaseModel():
    """ Base Model for ganomaly

    Holds the shared training / evaluation loops for the generator and
    discriminator (``netg`` / ``netd``) and for the two latent-space
    classifiers (``netc_i`` / ``netc_o``).

    This class never creates any of the following itself:
    - the networks and their optimisation steps (``optimize_params``,
      ``z_optimize_params``);
    - the input buffers (``input``, ``gt``, ``label``, ``fixed_input`` and
      the ``i_*`` / ``o_*`` tensors);
    - the ``name`` attribute.

    NOTE(review): presumably a subclass provides these — confirm.
    """

    def __init__(self, opt, dataloader):
        """Seed the RNGs and store options, visualizer, dataloaders and device.

        Args:
            opt: Parsed options object.
            dataloader (dict): Dataloaders keyed by split ('train', 'test').
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")
        # Truncated integer square root of the latent size.
        self.sqrtnz = int(self.opt.nz**0.5)

    ##
    def set_input(self, input: torch.Tensor):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i. ``input[0]`` is
                copied into ``self.input`` and ``input[1]`` into ``self.gt``.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])
                self.visualizer.save_fixed_real_s(self.fixed_input)

    def z_set_input(self, net, input: torch.Tensor):
        """Copy a latent-classifier batch into the ``i_*`` or ``o_*`` buffers.

        Args:
            net (str): ``'i'`` or ``'o'``, selecting which buffer set to fill.
            input: Pair whose ``[0]`` is the data and ``[1]`` the labels.
        """
        with torch.no_grad():
            if net == 'i':
                self.i_input.resize_(input[0].size()).copy_(input[0])
                self.i_gt.resize_(input[1].size()).copy_(input[1])
                self.i_label.resize_(input[1].size())
                self.i_real_label.resize_(input[1].size()).copy_(input[1])
            elif net == 'o':
                self.o_input.resize_(input[0].size()).copy_(input[0])
                self.o_gt.resize_(input[1].size()).copy_(input[1])
                self.o_label.resize_(input[1].size())
                self.o_real_label.resize_(input[1].size()).copy_(input[1])

    ##
    def seed(self, seed_value):
        """ Seed

        Arguments:
            seed_value {int} -- Seed for python, numpy and torch RNGs;
                ``-1`` leaves every RNG untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_g_adv', self.err_g_adv.item()),
                              ('err_g_con', self.err_g_con.item()),
                              ('err_g_enc', self.err_g_enc.item())])
        return errors

    ##
    def z_get_errors(self, net):
        """ Get the loss of one latent classifier.

        Args:
            net (str): ``'i'`` for ``err_i`` or ``'o'`` for ``err_o``.

        Returns:
            [OrderedDict]: Dictionary containing the selected error.

        Raises:
            ValueError: ``net`` is neither ``'i'`` nor ``'o'``.
        """
        if net == 'i':
            return OrderedDict([('err_i', self.err_i.item())])
        if net == 'o':
            return OrderedDict([('err_o', self.err_o.item())])
        # Previously fell through to an UnboundLocalError on ``errors``.
        raise ValueError("net must be 'i' or 'o', got %r" % (net, ))

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed, fixed_reals]: current input batch, current
            generator output, first element of ``netg(fixed_input)`` and the
            fixed input batch itself.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        fixed_reals = self.fixed_input.data
        return reals, fakes, fixed, fixed_reals

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number (stored as ``epoch + 1``).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.
        """
        self.netg.train()
        if self.opt.strengthen:
            self.netd.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed, fixed_reals = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(
                        reals, fakes, fixed, fixed_reals)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model for epochs ``opt.iter`` .. ``opt.niter - 1``.

        After each epoch the model is tested and its weights are saved
        whenever the test AUC improves.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    @staticmethod
    def _normalize_scores(scores):
        """Min-max scale ``scores`` into [0, 1].

        A constant score vector (zero range) maps to all zeros instead of the
        NaNs a plain ``0 / 0`` division would produce.
        """
        low = torch.min(scores)
        span = torch.max(scores) - low
        if span == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Scores each test sample by the mean squared difference between the
        two latent codes returned by ``netg``. The scores are min-max scaled
        to [0, 1] and evaluated against the ground truth with ``opt.metric``.
        Latent codes, ``netd`` predictions and ``netd`` features are stored
        on ``self`` for the visualizers and the latent classifiers.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` and ``'AUC'``.

        Raises:
            IOError: Model weights not found.
        """
        if self.opt.strengthen:
            self.netg.eval()
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name.lower(), self.opt.dataset)
                # torch.load is what fails on a missing file, so it has to be
                # inside the try for this error message to be reachable.
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                    self.netg.load_state_dict(pretrained_dict)
                except IOError as err:
                    raise IOError("netG weights not found") from err
                print(' Loaded weights.')

            self.opt.phase = 'test'

            n_samples = len(self.dataloader['test'].dataset)
            batch = self.opt.batchsize
            # Feature-map shape comes from the third-to-last layer of netd's
            # first child module; look it up once instead of per batch.
            feat_layer = list(self.netd.children())[0][-3]
            feat_shape = (feat_layer.out_channels, feat_layer.kernel_size[0],
                          feat_layer.kernel_size[1])

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(n_samples, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_samples, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_samples, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_samples, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.d_pred = torch.zeros(size=(n_samples, ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.last_feature = torch.zeros(size=(n_samples, ) + feat_shape,
                                            dtype=torch.float32,
                                            device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += batch
                epoch_iter += batch
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)
                d_pred, features = self.netd(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Slot of this batch in the per-sample buffers.
                n = error.size(0)
                start = i * batch
                stop = start + n
                self.an_scores[start:stop] = error.reshape(n)
                self.gt_labels[start:stop] = self.gt.reshape(n)
                self.latent_i[start:stop, :] = latent_i.reshape(n, self.opt.nz)
                self.latent_o[start:stop, :] = latent_o.reshape(n, self.opt.nz)
                self.d_pred[start:start + d_pred.size(0)] = d_pred.reshape(
                    d_pred.size(0))
                self.last_feature[start:stop, :] = features.reshape(
                    (n, ) + feat_shape)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (mean of the first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._normalize_scores(self.an_scores)
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)',
                                        self.times), ('AUC', auc)])

            if self.opt.strengthen and self.opt.phase == 'test':
                # Visualizations run on background threads.
                t0 = threading.Thread(
                    target=self.visualizer.display_scores_histo,
                    name='histogram ',
                    args=(self.epoch, self.an_scores, self.gt_labels))
                t0.start()
                if self.opt.strengthen > 1:
                    t1 = threading.Thread(
                        target=self.visualizer.display_feature,
                        name='t-SNE visualizer',
                        args=(self.last_feature, self.gt_labels))
                    t2 = threading.Thread(
                        target=self.visualizer.display_latent,
                        name='latent LDA visualizer',
                        args=(self.latent_i, self.latent_o, self.gt_labels,
                              9, 1000, True))
                    t1.start()
                    t2.start()

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_samples
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)

            if self.opt.classifier:
                self.z_dataloader = set_dataset(self.opt, self.latent_i,
                                                self.latent_o, self.gt_labels)
            return performance

    ##
    def z_save_weights(self, epoch):
        """Save the two latent classifiers (``netc_i``, ``netc_o``).

        Args:
            epoch (int): Current epoch number (stored as ``epoch + 1``).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_i.state_dict()
        }, '%s/netC_i.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_o.state_dict()
        }, '%s/netC_o.pth' % (weight_dir))

    ##
    def z_train(self):
        """ Train the latent classifiers and keep the best ``opt.z_metric``.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best = 0

        # Train for niter epochs.
        print(">> Training model classifier")
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.z_train_one_epoch()
            res = self.z_test()
            if res[self.opt.z_metric] > best:
                best = res[self.opt.z_metric]
                self.z_save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best,
                                                      self.opt.z_metric)
        self.visualizer.record_best(best, self.opt.z_metric,
                                    self.opt.abnormal_class,
                                    self.opt.manualseed, 'ganomaly_s')
        print(">> Training model %s.[Done]" % self.name)

    ##
    def z_train_one_epoch(self):
        """ Train both latent classifiers for one epoch.

        Runs ``z_optimize_params('i')`` over ``z_dataloader['i_train']``,
        then ``z_optimize_params('o')`` over ``z_dataloader['o_train']``.
        """
        self.netc_i.train()
        self.netc_o.train()
        epoch_iter = 0
        for data in tqdm(self.z_dataloader['i_train'],
                         leave=False,
                         total=len(self.z_dataloader['i_train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.z_set_input('i', data)
            self.z_optimize_params('i')

        for data in tqdm(self.z_dataloader['o_train'],
                         leave=False,
                         total=len(self.z_dataloader['o_train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.z_set_input('o', data)
            self.z_optimize_params('o')

            if self.total_steps % self.opt.print_freq == 0:
                # NOTE(review): reports err_i while iterating 'o' data —
                # confirm this is intended.
                errors = self.z_get_errors('i')
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.z_dataloader['i_train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def z_test(self):
        """ Test the latent classifiers.

        The combined prediction is ``w_i * i_pred + w_o * o_pred``, evaluated
        against the 'o' ground truth with ``opt.z_metric``.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` and ``opt.z_metric``.

        Raises:
            IOError: Model weights not found.
        """
        self.netd.eval()
        self.netc_i.eval()
        self.netc_o.eval()
        with torch.no_grad():
            # Load the weights of netd and the latent classifiers.
            if self.opt.z_load_weights:
                d_path = "./output/{}/{}/train/weights/netD.pth".format(
                    self.name.lower(), self.opt.dataset)
                i_path = "./output/{}/{}/train/weights/netC_i.pth".format(
                    self.name.lower(), self.opt.dataset)
                o_path = "./output/{}/{}/train/weights/netC_o.pth".format(
                    self.name.lower(), self.opt.dataset)
                # torch.load must be inside the try: it raises on missing files.
                try:
                    d_pretrained_dict = torch.load(d_path)['state_dict']
                    i_pretrained_dict = torch.load(i_path)['state_dict']
                    o_pretrained_dict = torch.load(o_path)['state_dict']
                    self.netd.load_state_dict(d_pretrained_dict)
                    self.netc_i.load_state_dict(i_pretrained_dict)
                    self.netc_o.load_state_dict(o_pretrained_dict)
                except IOError as err:
                    raise IOError("net weights not found") from err
                print(' Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.i_pred = torch.zeros(size=(len(
                self.z_dataloader['i_test'].dataset), ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.o_pred = torch.zeros(size=(len(
                self.z_dataloader['o_test'].dataset), ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.i_gt_labels = torch.zeros(size=(len(
                self.z_dataloader['i_test'].dataset), ),
                                           dtype=torch.long,
                                           device=self.device)
            self.o_gt_labels = torch.zeros(size=(len(
                self.z_dataloader['o_test'].dataset), ),
                                           dtype=torch.long,
                                           device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            batch = self.opt.batchsize
            for i, data in enumerate(self.z_dataloader['i_test'], 0):
                self.total_steps += batch
                epoch_iter += batch
                time_i = time.time()
                self.z_set_input('i', data)
                i_pred = self.netc_i(self.i_input)
                time_o = time.time()
                start = i * batch
                self.i_pred[start:start + i_pred.size(0)] = i_pred.reshape(
                    i_pred.size(0))
                self.i_gt_labels[start:start +
                                 i_pred.size(0)] = self.i_gt.reshape(
                                     self.i_gt.size(0))
                self.times.append(time_o - time_i)

            for i, data in enumerate(self.z_dataloader['o_test'], 0):
                self.total_steps += batch
                epoch_iter += batch
                time_i = time.time()
                self.z_set_input('o', data)
                o_pred = self.netc_o(self.o_input)
                time_o = time.time()
                start = i * batch
                self.o_pred[start:start + o_pred.size(0)] = o_pred.reshape(
                    o_pred.size(0))
                self.o_gt_labels[start:start +
                                 o_pred.size(0)] = self.o_gt.reshape(
                                     self.o_gt.size(0))
                self.times.append(time_o - time_i)

            # Measure inference time (mean of the first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)
            # Weighted combination of the two classifiers' predictions.
            self.pred_c = self.i_pred.cpu() * self.opt.w_i + \
                self.o_pred.cpu() * self.opt.w_o
            scores = evaluate(self.o_gt_labels.cpu(), self.pred_c,
                              self.opt.z_metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)',
                                        self.times),
                                       (self.opt.z_metric, scores)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.z_dataloader['i_test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
class Skipganomaly: """Skip-GANomaly Class """ @property def name(self): return 'skip-ganomaly' def __init__(self, opt, data=None): # Initalize variables. self.opt = opt self.visualizer = Visualizer(opt) self.data = data self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train') self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test') self.inf_dir = os.path.join(self.opt.outf, self.opt.name, 'inference') self.device = torch.device( "cuda:0" if self.opt.device != "cpu" else "cpu") # -- Misc attributes self.epoch = 0 self.times = [] self.total_steps = 0 ## # Create and initialize networks from networks.py. self.netg = define_G(self.opt, norm='batch', use_dropout=False, init_type='normal') self.netd = define_D(self.opt, norm='batch', use_sigmoid=False, init_type='normal') ## #resume Training if self.opt.resume != '': print("\nLoading pre-trained networks.") self.opt.iter = torch.load( os.path.join(self.opt.resume, 'netG.pth'))['epoch'] self.netg.load_state_dict( torch.load(os.path.join(self.opt.resume, 'netG.pth'))['state_dict']) self.netd.load_state_dict( torch.load(os.path.join(self.opt.resume, 'netD.pth'))['state_dict']) print("\tDone.\n") print(self.netg) print(self.netd) ## # Loss Functions self.l_adv = nn.BCELoss() self.l_con = nn.L1Loss() self.l_lat = l2_loss ## # Initialize input tensors. 
self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.noise = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.label = torch.empty(size=(self.opt.batchsize, ), dtype=torch.float32, device=self.device) self.gt = torch.empty(size=(opt.batchsize, ), dtype=torch.long, device=self.device) self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32, device=self.device) self.real_label = torch.ones(size=(self.opt.batchsize, ), dtype=torch.float32, device=self.device) self.fake_label = torch.zeros(size=(self.opt.batchsize, ), dtype=torch.float32, device=self.device) ## # Setup optimizer if self.opt.phase == "train": self.netg.train() self.netd.train() self.optimizers = [] self.optimizer_d = optim.Adam(self.netd.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizer_g = optim.Adam(self.netg.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) self.optimizers.append(self.optimizer_d) self.optimizers.append(self.optimizer_g) self.schedulers = [ get_scheduler(optimizer, opt) for optimizer in self.optimizers ] ## def set_input(self, input: torch.Tensor, noise: bool = False): """ Set input and ground truth Args: input (FloatTensor): Input data for batch i. """ with torch.no_grad(): self.input.resize_(input[0].size()).copy_(input[0]) self.gt.resize_(input[1].size()).copy_(input[1]) self.label.resize_(input[1].size()) # Add noise to the input. if noise: self.noise.data.copy_(torch.randn(self.noise.size())) # Copy the first batch as the fixed input. if self.total_steps == self.opt.batchsize: self.fixed_input.resize_(input[0].size()).copy_(input[0]) ## def get_errors(self): """ Get netD and netG errors. Returns: [OrderedDict]: Dictionary containing errors. 
""" errors = OrderedDict([('err_d', self.err_d), ('err_g', self.err_g), ('err_g_adv', self.err_g_adv), ('err_g_con', self.err_g_con), ('err_g_lat', self.err_g_lat)]) return errors ## def reinit_d(self): """ Initialize the weights of netD """ self.netd.apply(weights_init) print('Reloading d net') ## def get_current_images(self): """ Returns current images. Returns: [reals, fakes, fixed] """ reals = self.input.data fakes = self.fake.data fixed = self.netg(self.fixed_input)[0].data return reals, fakes, fixed ## def save_weights(self, epoch: int, is_best: bool = False): """Save netG and netD weights for the current epoch. Args: epoch ([int]): Current epoch number. """ name = self.opt.dataset if self.opt.dataset else self.opt.dataroot.split( "/")[-1] weight_dir = os.path.join(self.opt.outf, name, 'train', 'weights') if not os.path.exists(weight_dir): os.makedirs(weight_dir) if is_best: torch.save({ 'epoch': epoch, 'state_dict': self.netg.state_dict() }, f'{weight_dir}/netG_best.pth') torch.save({ 'epoch': epoch, 'state_dict': self.netd.state_dict() }, f'{weight_dir}/netD_best.pth') else: torch.save({ 'epoch': epoch, 'state_dict': self.netd.state_dict() }, f"{weight_dir}/netD_{epoch}.pth") torch.save({ 'epoch': epoch, 'state_dict': self.netg.state_dict() }, f"{weight_dir}/netG_{epoch}.pth") def load_weights(self, epoch=None, is_best: bool = False, path=None): """ Load pre-trained weights of NetG and NetD Keyword Arguments: epoch {int} -- Epoch to be loaded (default: {None}) is_best {bool} -- Load the best epoch (default: {False}) path {str} -- Path to weight file (default: {None}) Raises: Exception -- [description] IOError -- [description] """ if epoch is None and is_best is False: raise Exception( 'Please provide epoch to be loaded or choose the best epoch.') if is_best: fname_g = f"netG_best.pth" fname_d = f"netD_best.pth" else: fname_g = f"netG_{epoch}.pth" fname_d = f"netD_{epoch}.pth" if path is None: name = self.opt.dataset if self.opt.dataset else 
self.opt.dataroot.split( "/")[-1] path_g = f"{self.opt.outf}/{name}/train/weights/{fname_g}" path_d = f"{self.opt.outf}/{name}/train/weights/{fname_d}" else: path_g = path + "/" + fname_g path_d = path + "/" + fname_d # Load the weights of netg and netd. print('>> Loading weights...') if len(self.opt.gpu_ids) == 0: weights_g = torch.load( path_g, map_location=lambda storage, loc: storage)['state_dict'] weights_d = torch.load( path_d, map_location=lambda storage, loc: storage)['state_dict'] else: weights_g = torch.load(path_g)['state_dict'] weights_d = torch.load(path_d)['state_dict'] try: # create new OrderedDict that does not contain `module.` new_weights_g = OrderedDict() new_weights_d = OrderedDict() for k, v in weights_g.items(): name = k[7:] # remove `module.` new_weights_g[name] = v for k, v in weights_d.items(): name = k[7:] # remove `module.` new_weights_d[name] = v # load params if len(self.opt.gpu_ids) == 0: weights_g = new_weights_g weights_d = new_weights_d self.netg.load_state_dict(weights_g) self.netd.load_state_dict(weights_d) except IOError: raise IOError("netG weights not found") print(' Done.') def forward(self): self.forward_g() self.forward_d() def forward_g(self): """ Forward propagate through netG """ #TODO: Check, why noised input is used self.fake = self.netg(self.input + self.noise) def forward_d(self): """ Forward propagate through netD """ self.pred_real, self.feat_real = self.netd(self.input) self.pred_fake, self.feat_fake = self.netd(self.fake) def backward_g(self): """ Backpropagate netg """ self.err_g_adv = self.opt.w_adv * self.l_adv(self.pred_fake, self.real_label) self.err_g_con = self.opt.w_con * self.l_con(self.fake, self.input) self.err_g_lat = self.opt.w_lat * self.l_lat( self.feat_fake, self.feat_real) # should be named discriminator if self.opt.verbose: print(f'err_g_adv: {str(self.err_g_adv)}') print(f'err_g_con: {str(self.err_g_con)}') print(f'err_g_lat: {str(self.err_g_lat)}') self.err_g = self.err_g_adv + self.err_g_con + 
self.err_g_lat self.err_g.backward(retain_graph=True) def backward_d(self): # Fake #print(f'pref_fake: {str(self.pred_fake)}') #print(f'self.fake_label: {str(self.fake_label)}') #print(f'self.pred_real: {str(self.pred_real)}') #print(f'self.real_label: {str(self.real_label)}') self.err_d_fake = self.l_adv(self.pred_fake, self.fake_label) # Real # pred_real, feat_real = self.netd(self.input) self.err_d_real = self.l_adv(self.pred_real, self.real_label) # Combine losses. # TODO: According to https://github.com/samet-akcay/skip-ganomaly/issues/18#issue-728932038 ... Check if lat loss has to be negative in discriminator backprob if self.opt.verbose: print(f'err_d_real: {str(self.err_d_real)}') print(f'err_d_fake: {str(self.err_d_fake)}') print(f'err_g_lat: {str(self.err_g_lat)}') self.err_d = self.err_d_real + self.err_d_fake + self.err_g_lat self.err_d.backward(retain_graph=True) def update_netg(self): """ Update Generator Network. """ self.optimizer_g.zero_grad() self.backward_g() def update_netd(self): """ Update Discriminator Network. """ self.optimizer_d.zero_grad() self.backward_d() ## def optimize_params(self): """ Optimize netD and netG networks. """ self.forward() self.update_netg() self.update_netd() self.optimizer_g.step() self.optimizer_d.step() if self.err_d < 1e-5: self.reinit_d() ## def train_one_epoch(self): """ Train the model for one epoch. 
""" self.opt.phase = "train" self.netg.train() self.netd.train() epoch_iter = 0 for data in tqdm(self.data.train, leave=False, total=len(self.data.train)): self.total_steps += self.opt.batchsize epoch_iter += self.opt.batchsize self.set_input(data) self.optimize_params() reals, fakes, fixed = self.get_current_images() errors = self.get_errors() if self.opt.display: self.visualizer.plot_current_errors(self.epoch, self.total_steps, errors) # Write images to tensorboard if self.total_steps % self.opt.save_image_freq == 0: self.visualizer.display_current_images( reals, fakes, fixed, train_or_test="train", global_step=self.total_steps) if self.total_steps % self.opt.save_image_freq == 0: self.visualizer.save_current_images(self.epoch, reals, fakes, fixed) print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch + 1, self.opt.niter)) ## def train(self): """ Train the model """ ## # TRAIN self.total_steps = 0 best_auc = 0 # Train for niter epochs. print( f">> Training {self.name} on {self.opt.dataset} to detect anomalies" ) for self.epoch in range(self.opt.iter, self.opt.niter): self.train_one_epoch() res = self.test() if res['auc'] > best_auc: best_auc = res['auc'] self.save_weights(self.epoch, is_best=True) self.visualizer.print_current_performance(res, best_auc) print(">> Training model %s.[Done]" % self.name) ## def test(self, plot_hist=True): """ Test GANomaly model. Args: data ([type]): Dataloader for the test set Raises: IOError: Model weights not found. """ self.netg.eval() self.netd.eval() with torch.no_grad(): # Load the weights of netg and netd. if self.opt.path_to_weights is not None: self.load_weights(path=self.opt.path_to_weights, is_best=True) self.opt.phase = 'test' # Create big error tensor for the test set. 
self.an_scores = torch.zeros(size=(len(self.data.valid.dataset), ), dtype=torch.float32, device=self.device) self.gt_labels = torch.zeros(size=(len(self.data.valid.dataset), ), dtype=torch.long, device=self.device) print(" Testing %s" % self.name) self.times = [] total_steps_test = 0 epoch_iter = 0 i = 0 for data in tqdm(self.data.valid, leave=False, total=len(self.data.valid)): total_steps_test += self.opt.batchsize epoch_iter += self.opt.batchsize time_i = time.time() # Forward - Pass self.forward_for_testing(data) # Calculate the anomaly score. error = self.calculate_an_score() time_o = time.time() self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize + error.size(0)] = error.reshape(error.size(0)) self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize + error.size(0)] = self.gt.reshape(error.size(0)) if self.opt.verbose: print(f'an_scores: {str(self.an_scores)}') self.times.append(time_o - time_i) real, fake, fixed = self.get_current_images() if self.epoch * len( self.data.valid ) + total_steps_test % self.opt.save_image_freq == 0: self.visualizer.display_current_images( real, fake, fixed, train_or_test="test", global_step=self.epoch * len(self.data.valid) + total_steps_test) # Save test images. if self.opt.save_test_images: dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images') if not os.path.isdir(dst): os.makedirs(dst) #iterate over them (real) and write anomaly score and ground truth on filename vutils.save_image(real, '%s/real_%03d.png' % (dst, i + 1), normalize=True) vutils.save_image(fake, '%s/fake_%03d.png' % (dst, i + 1), normalize=True) i = i + 1 # Measure inference time. 
self.times = np.array(self.times) self.times = np.mean(self.times * 1000) # Scale error vector between [0, 1] # self.an_scores = (self.an_scores - torch.min(self.an_scores))/(torch.max(self.an_scores) - torch.min(self.an_scores)) if self.opt.verbose: print(f'scaled an_scores: {str(self.an_scores)}') y_trues = self.gt_labels.cpu() y_preds = self.an_scores.cpu() # Create data frame for scores and labels. performance, thresholds, y_preds_man, y_preds_auc = get_performance( y_trues=y_trues, y_preds=y_preds, manual_threshold=self.opt.decision_threshold) with open( os.path.join(self.opt.outf, self.opt.phase + "_results.txt"), "a+") as f: f.write(str(performance)) f.write("\n") f.close() self.visualizer.plot_histogram( y_trues=y_trues, y_preds=y_preds, threshold=performance["threshold"], save_path=os.path.join( self.opt.outf, "histogram_test" + str(self.epoch) + ".png"), tag="Histogram_Test", global_step=self.epoch) self.visualizer.plot_pr_curve(y_trues=y_trues, y_preds=y_preds, thresholds=thresholds, global_step=self.epoch, tag="PR_Curve_Test") self.visualizer.plot_roc_curve( y_trues=y_trues, y_preds=y_preds, global_step=self.epoch, tag="ROC_Curve_Test", save_path=os.path.join(self.opt.outf, "roc_test" + str(self.epoch) + ".png")) self.visualizer.plot_current_conf_matrix( self.epoch, performance["conf_matrix"], tag="Confusion_Matrix_Test", save_path=os.path.join(self.opt.outf, self.opt.phase + "_conf_matrix.png")) self.visualizer.plot_performance(self.epoch, 0, performance, tag="Performance_Test") return performance def forward_for_testing(self, data): self.set_input(data) self.fake = self.netg(self.input) real_clas, self.feat_real = self.netd(self.input) fake_clas, self.feat_fake = self.netd(self.fake) def inference(self): self.netg.eval() self.netd.eval() with torch.no_grad(): self.load_weights(path=self.opt.path_to_weights, is_best=True) # Create big error tensor for the test set. 
self.an_scores = torch.zeros(size=(len( self.data.inference.dataset), ), dtype=torch.float32, device=self.device) self.gt_labels = torch.zeros(size=(len( self.data.inference.dataset), ), dtype=torch.long, device=self.device) print("Starting Inference!") inf_time = None inf_times = [] self.file_names = [] for i, data in tqdm(enumerate(self.data.inference), leave=False, total=len(self.data.inference)): inf_start = time.time() # Forward - Pass self.forward_for_testing(data) # Calculate the anomaly score. error = self.calculate_an_score() inf_times.append(time.time() - inf_start) self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize + error.size(0)] = error.reshape(error.size(0)) self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize + error.size(0)] = self.gt.reshape(error.size(0)) if self.opt.verbose: print(f'an_scores: {str(self.an_scores)}') real, fake, fixed = self.get_current_images() if i % self.opt.save_image_freq == 0: self.visualizer.display_current_images( real, fake, fixed, train_or_test="test_inference", global_step=i) self.file_names.append(data[2]) # Measure inference time. # Scale error vector between [0, 1] TODO: does it work without normalizing? # self.an_scores = (self.an_scores - torch.min(self.an_scores))/(torch.max(self.an_scores) - torch.min(self.an_scores)) if self.opt.verbose: print(f'scaled an_scores: {str(self.an_scores)}') y_trues = self.gt_labels.cpu() y_preds = self.an_scores.cpu() inf_time = sum(inf_times) print(f'Inference time: {inf_time} secs') print(f'Inference time / individual: {inf_time/len(y_trues)} secs') # Create data frame for scores and labels. 
performance, thresholds, y_preds_man, y_preds_auc = get_performance( y_trues=y_trues, y_preds=y_preds, manual_threshold=self.opt.decision_threshold) with open(os.path.join(self.opt.outf, self.opt.phase + "_results.txt"), "w") as f: f.write(str(performance)) f.close() self.visualizer.plot_histogram(y_trues=y_trues, y_preds=y_preds, threshold=performance["threshold"], save_path=os.path.join( self.opt.outf, "histogram_inference.png"), tag="Histogram_Inference") self.visualizer.plot_pr_curve(y_trues=y_trues, y_preds=y_preds, thresholds=thresholds, global_step=1, tag="PR_Curve_Inference") self.visualizer.plot_roc_curve(y_trues=y_trues, y_preds=y_preds, global_step=1, tag="ROC_Curve_Inference", save_path=os.path.join( self.opt.outf, "roc_inference.png")) self.visualizer.plot_current_conf_matrix( 1, performance["conf_matrix"], tag="Confusion_Matrix_Inference", save_path=os.path.join(self.opt.outf, self.opt.phase + "_conf_matrix.png")) if self.opt.decision_threshold: self.visualizer.plot_current_conf_matrix( 2, performance["conf_matrix_man"], save_path=os.path.join(self.opt.outf, self.opt.phase + "_conf_matrix_man.png")) self.visualizer.plot_histogram( y_trues=y_trues, y_preds=y_preds, threshold=performance["manual_threshold"], global_step=2, save_path=os.path.join(self.opt.outf, "histogram_inference_man.png"), tag="Histogram_Inference") write_inference_result(file_names=self.file_names, y_trues=y_trues, y_preds=y_preds_man, outf=os.path.join( self.opt.outf, "classification_result_man.json")) self.visualizer.plot_performance(1, 0, performance, tag="Performance_Inference") write_inference_result(file_names=self.file_names, y_trues=y_trues, y_preds=y_preds_auc, outf=os.path.join(self.opt.outf, "classification_result.json")) ## # RETURN return performance def calculate_an_score(self): si = self.input.size() sz = self.feat_real.size() rec = (self.input - self.fake).view(si[0], si[1] * si[2] * si[3]) lat = (self.feat_real - self.feat_fake).view(sz[0], sz[1] * sz[2] * sz[3]) rec 
= torch.mean(torch.pow(rec, 2), dim=1) lat = torch.mean(torch.pow(lat, 2), dim=1) #print("lat", lat) #print("rec", rec) if self.opt.verbose: print(f'rec: {str(rec)}') print(f'lat: {str(lat)}') error = 0.9 * rec + 0.1 * lat return error
class BaseModel:
    """Base model for GANomaly.

    Holds the options, visualizer and dataloaders and implements the
    train / test / final-test loops. Concrete models are expected to provide
    ``netg``, ``netd``, ``name``, ``optimize_params`` and the ``input``,
    ``gt``, ``label``, ``fixed_input`` and ``total_steps`` attributes used
    below (NOTE(review): inferred from usage in this class, not defined here).
    """

    def __init__(self, opt, dataloader):
        # Seed for deterministic behavior.
        self.seed(opt.manualseed)

        # Initialize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

    def set_input(self, input: torch.Tensor):
        """Copy a batch into the model's input and ground-truth buffers.

        Args:
            input: Batch whose element 0 is the images and element 1 the labels.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    def seed(self, seed_value):
        """Seed python, numpy and torch RNGs; ``-1`` leaves them untouched.

        Args:
            seed_value (int): Seed to use, or -1 for no seeding.
        """
        if seed_value == -1:
            return

        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    def get_errors(self):
        """Get netD and netG errors.

        Returns:
            OrderedDict: Scalar value of each loss term.
        """
        return OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_enc', self.err_g_enc.item())])

    def get_current_images(self):
        """Return the current real batch, its reconstruction and netG(fixed_input).

        Returns:
            tuple: ``(reals, fakes, fixed)``.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        return reals, fakes, fixed

    def save_weights(self, epoch):
        """Save netG and netD weights to ``<outf>/<name>/train/weights``.

        Args:
            epoch (int): Current epoch number (stored as ``epoch + 1``).
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        os.makedirs(weight_dir, exist_ok=True)
        print(weight_dir)

        torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
                   '%s/netG.pth' % (weight_dir))
        torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
                   '%s/netD.pth' % (weight_dir))

    def _load_netg_weights(self, path):
        """Load generator weights from the checkpoint at ``path`` into ``self.netg``.

        Raises:
            IOError: The checkpoint file could not be read.
        """
        try:
            # map_location lets GPU-saved checkpoints load on CPU-only hosts.
            pretrained_dict = torch.load(path, map_location=self.device)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        self.netg.load_state_dict(pretrained_dict)
        print(' Loaded weights.')

    def train_one_epoch(self, epochNUM):
        """Train the model for one epoch.

        Args:
            epochNUM (int): 1-based index of this epoch within the run; used
                only to report progress through ``opt.signalInfo``.
        """
        self.netg.train()
        epoch_iter = 0
        num = len(self.dataloader['train'])
        for ii, data in enumerate(tqdm(self.dataloader['train'], leave=False, total=num)):
            # Progress over the training share (0.8 * 85 points starting at 10)
            # of the bar: completed epochs plus the fraction of this one. The
            # old formula restarted from 10 at the beginning of every epoch.
            progress = 10 + 0.8 * 85 * ((epochNUM - 1) + ii / num) / self.opt.niter
            self.opt.signalInfo.emit(progress, "")
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        message = ">> Training model %s. Epoch %d/%d" % (self.name, self.epoch + 1, self.opt.niter)
        print(message)

    def train(self):
        """Train for epochs [opt.iter, opt.niter), testing after each one.

        Weights are saved whenever the test AUC improves.

        Returns:
            dict | None: ``dict_info`` from the best-AUC test, or None if no
            epoch exceeded an AUC of 0.
        """
        self.total_steps = 0
        best_auc = 0
        best_info = None

        print(">> Training model %s." % self.name)
        self.opt.signalInfo.emit(-1, ">> Training model {}.".format(self.name))
        for done, self.epoch in enumerate(range(self.opt.iter, self.opt.niter), start=1):
            # Train for one epoch.
            self.opt.signalInfo.emit(-1, '正在进行第{}个epoch的训练....'.format(done))
            self.train_one_epoch(done)
            self.opt.signalInfo.emit(10 + 0.8 * 85 * (done / self.opt.niter),
                                     '第{}个epoch的训练完毕!\n正在对该epoch进行测试....'.format(done))
            res, info = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                best_info = info
                self.save_weights(self.epoch)
            infoSTR = "".join(str(key) + ":" + str(value) + "\n" for key, value in info.items())
            self.opt.signalInfo.emit(10 + 85 * (done / self.opt.niter),
                                     '测试完毕!\n第{}个epoch训练结果:\n{}'.format(done, infoSTR))
        print(">> Training model %s.[Done]" % self.name)
        self.opt.signalInfo.emit(-1, ">> Training model {}.[Done]".format(self.name))
        return best_info

    @staticmethod
    def _score_histograms(nrm_scr, abn_scr, n_bins=50):
        """Count normal and abnormal scores in ``n_bins`` equal-width bins over [0, 1].

        Bin ``k`` covers ``[k * w, (k + 1) * w)`` with ``w = 1 / n_bins``; the
        last bin also includes 1.0, so the sample normalized to the maximum is
        counted. Scores outside [0, 1] (or NaN) are not counted.

        Returns:
            tuple: ``(nrm, abn)`` integer count arrays of length ``n_bins``.
        """
        edges = np.arange(n_bins + 1) * (1.0 / n_bins)
        nrm = np.histogram(np.asarray(nrm_scr, dtype=np.float64), bins=edges)[0]
        abn = np.histogram(np.asarray(abn_scr, dtype=np.float64), bins=edges)[0]
        return nrm, abn

    @staticmethod
    def _select_threshold(nrm, abn, default=0.4, floor_ratio=0.25):
        """Pick the decision threshold that best separates the two histograms.

        For each cut bin ``k``, counts normals below the cut plus abnormals at
        or above it; the last ``k`` reaching the running maximum wins and maps
        to ``round(k / n_bins + 0.05, 3)``. Returns ``default`` if no cut
        reaches the initial floor.
        """
        n_bins = len(nrm)
        # NOTE(review): the floor is derived from the number of *bins*
        # (len(nrm) + len(abn)), not the number of samples; kept as-is.
        best = (len(nrm) + len(abn)) * floor_ratio
        proline = default
        for k in range(n_bins):
            separated = np.sum(nrm[0:k]) + np.sum(abn[k:])
            if separated >= best:
                proline = round((k / n_bins) + 0.05, 3)
                best = separated
        return proline

    def test(self):
        """Test the GANomaly model on ``self.dataloader['test']``.

        The anomaly score is the mean squared difference between netG's two
        latent outputs. Scores are min-max normalized, a decision threshold
        (``proline``) is derived from score histograms, and AUC is evaluated.

        Returns:
            tuple: ``(performance, dict_info)``. ``performance`` is an
            OrderedDict with run time and AUC. ``dict_info`` holds the raw
            score range, threshold, AUC and run time as floats.

        Raises:
            IOError: Model weights not found (when ``opt.load_weights``).
        """
        with torch.no_grad():
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                self._load_netg_weights(path)

            self.opt.phase = 'test'

            # Create big error tensors for the test set.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_test,), dtype=torch.long, device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            # NOTE(review): this resets the step counter that train_one_epoch and
            # set_input (fixed_input capture) also use; kept for compatibility.
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test']):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                start = i * self.opt.batchsize
                end = start + error.size(0)
                self.an_scores[start:end] = error.reshape(error.size(0))
                self.gt_labels[start:end] = self.gt.reshape(error.size(0))
                self.latent_i[start:end, :] = latent_i.reshape(error.size(0), self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1), normalize=True)

            # Measure inference time (mean over the first 100 batches, in ms).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1].
            min_score = torch.min(self.an_scores)
            max_score = torch.max(self.an_scores)
            print(min_score)
            print(max_score)
            span = max_score - min_score
            if span > 0:
                self.an_scores = (self.an_scores - min_score) / span
            else:
                # All scores identical: avoid 0/0 = NaN and map them to 0.
                self.an_scores = torch.zeros_like(self.an_scores)

            # -------------- Threshold selection --------------
            print('-------------- 处理阈值 ------------------')
            print(len(self.gt_labels))
            plt.ion()
            # Move to host memory first so this also works for CUDA tensors.
            scores_np = self.an_scores.cpu().numpy()
            labels_np = self.gt_labels.cpu().numpy()
            nrm_scr = scores_np[labels_np == 0]
            abn_scr = scores_np[labels_np == 1]
            print(len(nrm_scr))
            print(len(abn_scr))

            nrm, abn = self._score_histograms(nrm_scr, abn_scr)
            print(nrm)
            print(abn)

            proline = self._select_threshold(nrm, abn)
            print(proline)
            print(self.gt_labels[0:20])
            print(self.an_scores[0:20])
            print('------------- 处理阈值 END --------------')
            # ------------- Threshold selection END -------------

            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)

            # --- Values reported back to the caller ---
            dict_info = {}
            dict_info['minVal'] = float(min_score.item())
            dict_info['maxVal'] = float(max_score.item())
            dict_info['proline'] = float(proline)
            dict_info['auc'] = float(auc)
            dict_info['Avg Run Time (ms/batch)'] = float(self.times)
            return performance, dict_info

    def FinalTest(self, minVal, maxVal, threshold=0.2):
        """Classify the images under ``<dataroot>/final_test/`` as normal or abnormal.

        Each image's latent-difference score is normalized with the given
        range, ``(error - minVal) / (maxVal - minVal)``, and compared with
        ``threshold``.

        Returns:
            tuple: ``(normal, abnormal)`` dicts mapping file name to score.

        Raises:
            IOError: Model weights not found.
        """
        path = "./output/{}/{}/train/weights/netG.pth".format('ganomaly', self.opt.dataset)
        print('***' * 10)
        print(path)
        print('***' * 10)
        self.opt.signalInfo.emit(-1, '加载权重...')
        self._load_netg_weights(path)
        self.opt.signalInfo.emit(5, "权重加载完毕!\n正在加载图片....")

        path2 = self.opt.dataroot + '/final_test/'
        transform = transforms.Compose([transforms.Resize(self.opt.isize),
                                        transforms.CenterCrop(self.opt.isize),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
        testdata = ImageFolder(path2, transform)
        testdata2 = torch.utils.data.DataLoader(dataset=testdata,
                                                batch_size=1,
                                                shuffle=False,
                                                num_workers=int(self.opt.workers),
                                                drop_last=True)
        self.opt.signalInfo.emit(10, "图片加载完毕!\n正在进行检测....")

        # Take names from the dataset's own (sorted, image-only) sample list so
        # each score is keyed by the file it was computed from; os.listdir
        # order is arbitrary and may include non-image files.
        testFilesName = [os.path.basename(sample_path) for sample_path, _ in testdata.samples]
        n_files = len(testFilesName)
        testFilesResNor = {}
        testFilesResAbn = {}
        with torch.no_grad():
            for i, data2 in enumerate(testdata2):
                self.set_input(data2)
                self.fake, latent_i, latent_o = self.netg(self.input)
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                testscore = ((error - minVal) / (maxVal - minVal)).item()
                file_name = testFilesName[i]
                if testscore >= threshold:
                    testFilesResAbn[file_name] = testscore
                else:
                    testFilesResNor[file_name] = testscore
                self.opt.signalInfo.emit(10 + (i + 1) / n_files * 88, "")

        self.opt.signalInfo.emit(100, "图片检测完毕!")
        torch.cuda.empty_cache()
        return testFilesResNor, testFilesResAbn