class BaseModel():
    """ Base Model for ganomaly

    Bundles the seeding, input staging, checkpointing, training-loop and
    evaluation logic shared by the models.

    NOTE(review): this class reads several attributes it never defines
    (``netg``, ``netd``, ``input``, ``gt``, ``label``, ``noise``,
    ``fixed_input``, ``fake``, ``err_d``, ``err_g*``, ``name``,
    ``optimizers``, ``schedulers``) and calls ``optimize_params()``;
    presumably subclasses provide them -- confirm against the subclasses.
    """

    def __init__(self, opt, data):
        """ Store the options and data and derive output dirs and device.

        Args:
            opt: Options object. Attributes read here: ``manualseed``,
                ``outfolder``, ``name`` and ``device``.
            data: Data container; ``data.train`` and ``data.valid`` are
                iterated by the training and test loops.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initialize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.data = data
        self.trn_dir = os.path.join(self.opt.outfolder, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outfolder, self.opt.name, 'test')
        # Anything other than the literal 'cpu' selects the first CUDA device.
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def seed(self, seed_value):
        """ Seed the python, numpy and torch random number generators.

        Arguments:
            seed_value {int} -- Seed to apply. The value -1 is treated as
                "not set" and leaves every generator untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def set_input(self, input: torch.Tensor, noise: bool = False):
        """ Set input and ground truth

        Copies the batch into the pre-allocated ``self.input`` / ``self.gt``
        buffers and resizes ``self.label`` to match the ground truth.

        Args:
            input: Batch i; ``input[0]`` is copied into ``self.input`` and
                ``input[1]`` into ``self.gt``.
            noise (bool): When True, ``self.noise`` is refilled with fresh
                standard-normal samples.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Add noise to the input.
            if noise:
                self.noise.data.copy_(torch.randn(self.noise.size()))

            # Copy the first batch as the fixed input. The callers increment
            # total_steps by batchsize before calling, so equality marks the
            # first batch.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors as python scalars,
                keyed 'err_d', 'err_g', 'err_g_adv', 'err_g_con' and
                'err_g_lat' (in that order).
        """
        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_g_adv', self.err_g_adv.item()),
                              ('err_g_con', self.err_g_con.item()),
                              ('err_g_lat', self.err_g_lat.item())])
        return errors

    ##
    def reinit_d(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]: the current input batch, the current
                generated batch, and the first output of ``netg`` applied to
                the fixed input.
        """
        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data
        return reals, fakes, fixed

    ##
    def save_weights(self, epoch: int, is_best: bool = False):
        """Save netG and netD weights for the current epoch.

        Files are written to ``<outfolder>/<name>/train/weights``; the
        directory is created when missing.

        Args:
            epoch ([int]): Current epoch number. Stored in each checkpoint
                and, when ``is_best`` is False, used in the file names
                (``netG_<epoch>.pth`` / ``netD_<epoch>.pth``).
            is_best (bool): When True, write ``netG_best.pth`` /
                ``netD_best.pth`` instead of the epoch-numbered files.
        """
        weight_dir = os.path.join(self.opt.outfolder, self.opt.name, 'train',
                                  'weights')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(weight_dir, exist_ok=True)

        if is_best:
            torch.save({
                'epoch': epoch,
                'state_dict': self.netg.state_dict()
            }, f'{weight_dir}/netG_best.pth')
            torch.save({
                'epoch': epoch,
                'state_dict': self.netd.state_dict()
            }, f'{weight_dir}/netD_best.pth')
        else:
            torch.save({
                'epoch': epoch,
                'state_dict': self.netd.state_dict()
            }, f"{weight_dir}/netD_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'state_dict': self.netg.state_dict()
            }, f"{weight_dir}/netG_{epoch}.pth")

    def _read_state_dict(self, path, net_name):
        """ Read the 'state_dict' entry of the checkpoint stored at `path`.

        Args:
            path (str): Checkpoint file to read.
            net_name (str): Network label used in the error message.

        Raises:
            IOError: The checkpoint file could not be read.
        """
        try:
            # map_location lets checkpoints written on another device load
            # onto the device this model runs on.
            checkpoint = torch.load(path, map_location=self.device)
        except IOError as err:
            raise IOError(f"{net_name} weights not found") from err
        return checkpoint['state_dict']

    def load_weights(self, epoch=None, is_best: bool = False, path=None):
        """ Load pre-trained weights of NetG and NetD

        Keyword Arguments:
            epoch {int} -- Epoch to be loaded (default: {None})
            is_best {bool} -- Load the best epoch (default: {False})
            path {str} -- Directory containing the weight files. When None,
                ``./output/<self.name>/<opt.dataset>/train/weights`` is used
                (default: {None})

        Raises:
            Exception -- Neither `epoch` nor `is_best` was provided.
            IOError -- A weight file could not be read.
        """
        if epoch is None and not is_best:
            raise Exception(
                'Please provide epoch to be loaded or choose the best epoch.')

        if is_best:
            fname_g = "netG_best.pth"
            fname_d = "netD_best.pth"
        else:
            fname_g = f"netG_{epoch}.pth"
            fname_d = f"netD_{epoch}.pth"

        # Resolve both file paths for every value of `path`; previously they
        # were only assigned when `path` was None, so passing a path crashed.
        if path is None:
            path = f"./output/{self.name}/{self.opt.dataset}/train/weights"
        path_g = os.path.join(path, fname_g)
        path_d = os.path.join(path, fname_d)

        # Load the weights of netg and netd.
        print('>> Loading weights...')
        weights_g = self._read_state_dict(path_g, "netG")
        weights_d = self._read_state_dict(path_d, "netD")
        self.netg.load_state_dict(weights_g)
        self.netd.load_state_dict(weights_d)
        print(' Done.')

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.

        Every batch is staged with `set_input` and optimised with
        `optimize_params`; errors and images are reported whenever
        `total_steps` is a multiple of `opt.print_freq` /
        `opt.save_image_freq`.
        """
        self.netg.train()
        self.netd.train()
        epoch_iter = 0
        for data in tqdm(self.data.train,
                         leave=False,
                         total=len(self.data.train)):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.data.train.dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs `train_one_epoch` followed by `test` for every epoch in
        ``range(opt.iter, opt.niter)`` and saves the weights whenever the
        test AUC improves on the best value seen so far.
        """
        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(
            f">> Training {self.name} on {self.opt.dataset} to detect {self.opt.abnormal_class}"
        )
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                # NOTE(review): this writes epoch-numbered files, so the
                # *_best.pth files that load_weights(is_best=True) expects
                # are never produced here -- confirm whether that is intended.
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
            loss = 'gloss: ' + str(self.err_g.item()) + ' dloss:' + str(
                self.err_d.item())
            print(loss)
            self.visualizer.write_to_log_file(loss)
        print(">> Training model %s.[Done]" % self.name)

    ##
    def test(self):
        """ Test GANomaly model.

        Scores every sample of ``self.data.valid`` with the mean squared
        difference between the two latent vectors returned by ``netg``,
        rescales the scores to [0, 1] and evaluates them against the ground
        truth labels with ``opt.metric``.

        Returns:
            [OrderedDict]: 'Avg Run Time (ms/batch)' (mean over at most the
                first 100 batches) and 'AUC'.

        Raises:
            IOError: Model weights not found.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name.lower(), self.opt.dataset)
                # The file read happens inside the helper's try block, so a
                # missing file is reported as the documented IOError.
                pretrained_dict = self._read_state_dict(path, "netG")
                self.netg.load_state_dict(pretrained_dict)
                print(' Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            n_samples = len(self.data.valid.dataset)
            self.an_scores = torch.zeros(size=(n_samples, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_samples, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_samples, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_samples, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            # print(" Testing model %s." % self.name)
            self.times = []
            # NOTE(review): this also resets the step counter that
            # train_one_epoch keeps accumulating -- confirm that is intended.
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.data.valid, 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Rows of the result buffers covered by this batch; the last
                # batch may be shorter than batchsize.
                n_batch = error.size(0)
                start = i * self.opt.batchsize
                stop = start + n_batch
                self.an_scores[start:stop] = error.reshape(n_batch)
                self.gt_labels[start:stop] = self.gt.reshape(n_batch)
                self.latent_i[start:stop, :] = latent_i.reshape(
                    n_batch, self.opt.nz)
                self.latent_o[start:stop, :] = latent_o.reshape(
                    n_batch, self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outfolder, self.opt.name,
                                       'test', 'images')
                    os.makedirs(dst, exist_ok=True)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            # NOTE(review): divides by zero when every score is identical.
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.data.valid.dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance

    ##
    def update_learning_rate(self):
        """ Update learning rate based on the rule provided in options.

        Steps every scheduler in ``self.schedulers`` and prints the learning
        rate of the first parameter group of the first optimizer.
        """
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print(' LR = %.7f' % lr)