def test_AE(ae_net, dataset, batch_size=16, n_jobs_dataloader=4, device='cuda'):
    """
    Score every sample of a dataset with the reconstruction error of an
    autoencoder.
    ----------
    INPUT
        |---- ae_net (nn.Module) the autoencoder. It is called as
        |           ae_net(input) and its output is used as the reconstruction.
        |---- dataset (torch.utils.data.Dataset) the dataset to score. Each
        |           batch is unpacked as (input, label, mask, _, idx).
        |---- batch_size (int) the batch size of the dataloader.
        |---- n_jobs_dataloader (int) the number of dataloader workers.
        |---- device (str) the device on which the network is evaluated.
    OUTPUT
        |---- test_time (float) the wall-clock duration of the scoring loop
        |           in seconds.
        |---- scores (list of tuple) one (idx, label, score) tuple per
        |           sample, where score is the unreduced MaskedMSELoss
        |           averaged over all non-batch dimensions.
    """
    # make test dataloader using image and mask
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=n_jobs_dataloader)
    # MSE loss without reduction --> one loss value per output element
    criterion = MaskedMSELoss(reduction='none')
    # set to device
    ae_net = ae_net.to(device)
    criterion = criterion.to(device)

    # Testing
    n_batch_tot = len(loader)
    start_time = time.time()
    idx_label_score = []
    # put network in evaluation mode
    ae_net.eval()
    with torch.no_grad():
        for b, data in enumerate(loader):
            image, label, mask, _, idx = data
            # put inputs to device
            image, label = image.to(device).float(), label.to(device)
            mask, idx = mask.to(device), idx.to(device)

            rec = ae_net(image)
            rec_loss = criterion(rec, image, mask)
            # mean over every dimension except the batch one
            score = torch.mean(rec_loss, dim=tuple(range(1, rec.dim())))

            # append scores and label
            idx_label_score += list(zip(idx.cpu().tolist(),
                                        label.cpu().tolist(),
                                        score.cpu().tolist()))
            # NOTE: the overall batch loss that used to be accumulated here
            # was never read nor returned; it only forced a device sync per
            # batch, so it has been dropped.
            print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

    test_time = time.time() - start_time
    return test_time, idx_label_score
def test(self, dataset, net):
    """
    Test the joint DeepSVDD network on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is tested. Each batch is unpacked as
        |           (input, label, mask, semi_label, idx).
        |---- net (nn.Module) The DeepSVDD to test. Its forward pass is
        |           unpacked as (reconstruction, embedding).
    OUTPUT
        |---- None. The results are stored on the instance: test_time,
        |           test_scores_rec, test_auc_rec, test_f1_rec,
        |           test_scores_ad, test_auc_ad and test_f1_ad.
    """
    logger = logging.getLogger()

    def as_list(tensor):
        # bring a tensor back to the host as a plain python list
        return tensor.cpu().data.numpy().tolist()

    # dataloader over the test samples
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # network on device, returning both reconstruction and embedding
    net = net.to(self.device)
    net.return_svdd_embed = True

    # reconstruction criterion (unreduced) and anomaly detection criterion
    criterion_rec = MaskedMSELoss(reduction='none')
    criterion_ad = self.SVDDLoss(self.space_repr, self.nu, eps=self.eps,
                                 soft_boundary=self.soft_boundary)

    logger.info('>>> Start Testing the joint DeepSVDD and AutoEncoder.')
    running_loss = 0.0
    n_done = 0
    n_batch_tot = len(test_loader)
    start_time = time.time()
    scores_rec = []
    scores_ad = []

    net.eval()
    with torch.no_grad():
        for b, batch in enumerate(test_loader):
            image, label, mask, semi_label, idx = batch
            # move the whole batch to the device
            image = image.to(self.device).float()
            label = label.to(self.device)
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            idx = idx.to(self.device)

            # keep only the masked part of the input
            image = image * mask

            # forward pass and both losses
            rec, embed = net(image)
            loss_rec = criterion_rec(rec, image, mask)
            loss_ad = criterion_ad(embed, self.R)

            # reconstruction score: mean over all non-batch dimensions
            non_batch_dims = tuple(range(1, rec.dim()))
            rec_score = torch.mean(loss_rec, dim=non_batch_dims)

            # AD score: squared distance to the subspace projection or to
            # the center (large distances highlight anomalies)
            if self.use_subspace:
                projected = torch.matmul(
                    self.space_repr, embed.transpose(0, 1)).transpose(0, 1)
                dist = torch.sum((embed - projected)**2, dim=1)
            else:
                dist = torch.sum((embed - self.space_repr)**2, dim=1)
            ad_score = dist - self.R**2 if self.soft_boundary else dist

            # overall weighted loss of the batch
            mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
            loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
            loss += self.scale_em * self.criterion_weight[1] * loss_ad

            # store (idx, label, score) for both scores
            scores_rec += list(
                zip(as_list(idx), as_list(label), as_list(rec_score)))
            scores_ad += list(
                zip(as_list(idx), as_list(label), as_list(ad_score)))

            running_loss += loss.item()
            n_done += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

    self.test_time = time.time() - start_time

    # metrics on the reconstruction score
    self.test_scores_rec = scores_rec
    _, label, rec_score = zip(*scores_rec)
    label = np.array(label)
    rec_score = np.array(rec_score)
    self.test_auc_rec = roc_auc_score(label, rec_score)
    self.test_f1_rec = f1_score(
        label, np.where(rec_score > self.scores_threhold_rec, 1, 0))

    # metrics on the anomaly detection score
    self.test_scores_ad = scores_ad
    _, label, ad_score = zip(*scores_ad)
    label = np.array(label)
    ad_score = np.array(ad_score)
    self.test_auc_ad = roc_auc_score(label, ad_score)
    self.test_f1_ad = f1_score(
        label, np.where(ad_score > self.scores_threhold_ad, 1, 0))

    # add info to logger
    logger.info(f'>>> Test Time: {self.test_time:.3f} [s]')
    logger.info(f'>>> Test Loss: {running_loss / n_done:.6f}')
    logger.info(f'>>> Test reconstruction AUC: {self.test_auc_rec:.3%}')
    logger.info(
        f'>>> Test F1-score on reconstruction score: {self.test_f1_rec:.3%}')
    logger.info(f'>>> Test AD AUC: {self.test_auc_ad:.3%}')
    logger.info(
        f'>>> Test F1-score on DeepSVDD score: {self.test_f1_ad:.3%}')
    logger.info('>>> Finished Testing the Joint DeepSVDD and AutoEncoder.\n')
def train(self, dataset, net, valid_dataset=None):
    """
    Train the joint DeepSVDD network on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. Each batch is unpacked as
        |           (input, _, mask, semi_label, _).
        |---- net (nn.Module) The DeepSVDD to train. The network should be an
        |           autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
        |---- valid_dataset (torch.utils.data.Dataset) optional dataset
        |           passed to self.validate at the end of each epoch. No
        |           validation is performed if it is falsy.
    OUTPUT
        |---- net (nn.Module) The trained joint DeepSVDD. The per-epoch
        |           losses and the total duration are stored in
        |           self.train_loss and self.train_time.
    """
    logger = logging.getLogger()

    # make dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # put net to device
    net = net.to(self.device)
    net.return_svdd_embed = True  # enable the network to provide the SVDD embedding

    # initialize hypersphere center or subspace projection matrix
    if self.space_repr is None:
        if self.use_subspace:
            logger.info('>>> Initializing the subspace projection matrix.')
            self.space_repr = self.initialize_projection_matrix(
                train_loader, net)
            logger.info('>>> Projection matrix succesfully initialized.')
        else:
            logger.info('>>> Initializing the hypersphere center.')
            self.space_repr = self.initialize_hypersphere_center(
                train_loader, net)
            logger.info('>>> Center succesfully initialized.')

    # define the two criterion for Anomaly detection and reconstruction
    criterion_rec = MaskedMSELoss()
    criterion_ad = self.SVDDLoss(self.space_repr, self.nu, eps=self.eps,
                                 soft_boundary=self.soft_boundary)

    # compute the scale weight so that the rec and svdd losses are scaled and comparable
    self.initialize_loss_scale_weight(train_loader, net, criterion_rec,
                                      criterion_ad)

    # define optimizer
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)

    # define scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Start training
    logger.info('>>> Start Training the Joint DeepSVDD and Autoencoder.')
    start_time = time.time()
    epoch_loss_list = []
    n_batch_tot = len(train_loader)
    # set network in train mode
    net.train()
    for epoch in range(self.n_epoch):
        epoch_loss = 0.0
        n_batch = 0
        epoch_start_time = time.time()
        dist = []
        # the radius is only updated with a soft boundary, after the warm-up
        update_radius = self.soft_boundary and (epoch + 1 > self.n_epoch_warm_up)

        for b, data in enumerate(train_loader):
            image, _, mask, semi_label, _ = data
            # put inputs to device
            image = image.to(self.device).float()
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            image.requires_grad = True

            # mask the input (keep only the object)
            image = image * mask

            # zeros the gradient
            optimizer.zero_grad()

            # Update network parameters by backpropagation on the two criterion
            rec, embed = net(image)
            # reconstruction loss
            # ignore reconstruction for known abnormal samples: where
            # semi_label == -1 the reconstruction is replaced by the input
            # itself, so the loss (and its gradient) is zero there
            rec = torch.where(
                semi_label.view(-1, 1, 1, 1).expand(*image.shape) != -1,
                rec, image)
            loss_rec = criterion_rec(rec, image, mask)
            loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
            # SVDD embedding loss
            loss_ad = criterion_ad(embed, self.R)
            loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad

            loss = loss_rec + loss_ad

            loss.backward()
            optimizer.step()

            # compute dist to update radius R
            if update_radius:
                if self.use_subspace:
                    projected = torch.matmul(
                        self.space_repr,
                        embed.transpose(0, 1)).transpose(0, 1)
                    dist.append(
                        torch.sum((embed - projected)**2, dim=1).detach())
                else:
                    dist.append(
                        torch.sum((self.space_repr - embed)**2,
                                  dim=1).detach())

            epoch_loss += loss.item()
            n_batch += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        # update radius
        if update_radius:
            self.R.data = torch.tensor(
                self.get_radius(torch.cat(dist, dim=0)), device=self.device)

        valid_auc = ''
        if valid_dataset:
            auc_rec, auc_ad = self.validate(valid_dataset, net, final=False)
            # go back to train mode for the next epoch
            net.train()
            valid_auc = f' Rec AUC: {auc_rec:.3%} | AD AUC: {auc_ad:.3%} | R {self.R:.3f} |'

        # epoch statistic
        epoch_train_time = time.time() - epoch_start_time
        logger.info(
            f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | Train Time: {epoch_train_time:.3f} [s] '
            f'| Train Loss: {epoch_loss / n_batch:.6f} |' + valid_auc)

        # append the epoch loss to results list
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # update the learning rate if the milestone is reached
        scheduler.step()
        if epoch + 1 in self.lr_milestone:
            # Read the rate actually applied from the optimizer:
            # scheduler.get_lr() is not meant to be called from user code
            # and, in recent PyTorch releases, can report a value decayed
            # one extra time when called right after a milestone.
            current_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f'>>> LR Scheduler : new learning rate {current_lr:g}')

    # End training
    self.train_loss = epoch_loss_list
    self.train_time = time.time() - start_time
    logger.info(
        f'>>> Training of Joint DeepSVDD and AutoEncoder Time: {self.train_time:.3f} [s]'
    )
    logger.info('>>> Finished Joint DeepSVDD and AutoEncoder Training.\n')

    return net
def pretrain(self, dataset, net):
    """
    Pretrain the AE for the joint DeepSVDD network on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is pretrained. Each batch is unpacked as
        |           (input, _, mask, semi_label, _).
        |---- net (nn.Module) The DeepSVDD to pretrain. Its forward pass is
        |           unpacked as a pair of which only the first element (the
        |           reconstruction) is used.
    OUTPUT
        |---- net (nn.Module) The pretrained joint DeepSVDD. The per-epoch
        |           losses and the total duration are stored in
        |           self.pretrain_loss and self.pretrain_time.
    """
    logger = logging.getLogger()

    # dataloader over the training samples
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # network on device, embedding output disabled
    net = net.to(self.device)
    net.return_svdd_embed = False

    # reconstruction criterion and optimizer
    criterion_rec = MaskedMSELoss()
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)

    logger.info('>>> Start Pretraining the Autoencoder.')
    start_time = time.time()
    loss_history = []
    n_batch_tot = len(train_loader)

    net.train()
    for epoch in range(self.n_epoch_pretrain):
        running_loss = 0.0
        n_done = 0
        epoch_start_time = time.time()

        for b, batch in enumerate(train_loader):
            image, _, mask, semi_label, _ = batch
            # move the batch to the device
            image = image.to(self.device).float()
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            image.requires_grad = True

            # keep only the masked part of the input
            image = image * mask

            optimizer.zero_grad()

            # forward pass; where semi_label == -1 the reconstruction is
            # replaced by the input so that those samples yield a zero loss
            rec, _ = net(image)
            keep_rec = semi_label.view(-1, 1, 1, 1).expand(*image.shape) != -1
            rec = torch.where(keep_rec, rec, image)

            # backward pass and parameter update
            loss = criterion_rec(rec, image, mask)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_done += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        # epoch statistic
        epoch_train_time = time.time() - epoch_start_time
        logger.info(
            f'| Epoch: {epoch + 1:03}/{self.n_epoch_pretrain:03} | Pretrain Time: {epoch_train_time:.3f} [s] '
            f'| Pretrain Loss: {running_loss / n_done:.6f} |')

        loss_history.append([epoch + 1, running_loss / n_done])

    # End training
    self.pretrain_loss = loss_history
    self.pretrain_time = time.time() - start_time
    logger.info(
        f'>>> Pretraining of AutoEncoder Time: {self.pretrain_time:.3f} [s]')
    logger.info('>>> Finished of AutoEncoder Pretraining.\n')

    return net
def validate(self, dataset, net):
    """
    Validate the joint DMSVDD network on the provided dataset and find the
    best threshold on the score to maximize the f1-score.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is validated. Each batch is unpacked as
        |           (input, label, mask, semi_label, idx).
        |---- net (nn.Module) The DMSVDD to validate. The network should be
        |           an autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
    OUTPUT
        |---- None. The results are stored on the instance: valid_time,
        |           valid_scores_rec, valid_auc_rec, scores_threhold_rec,
        |           valid_f1_rec, valid_scores_ad, valid_auc_ad,
        |           scores_threhold_ad and valid_f1_ad.
    """
    logger = logging.getLogger()

    # make test dataloader using image and mask
    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # put net to device
    net = net.to(self.device)
    net.return_svdd_embed = True

    # define the two criterion for Anomaly detection and reconstruction
    criterion_rec = MaskedMSELoss(reduction='none')
    criterion_ad = DMSVDDLoss(self.nu, eps=self.eps,
                              soft_boundary=self.soft_boundary)

    # Testing
    logger.info('>>> Start Validating of the joint DMSVDD and AutoEncoder.')
    epoch_loss = 0.0
    n_batch = 0
    n_batch_tot = len(valid_loader)
    start_time = time.time()
    idx_label_score_rec = []
    idx_label_score_ad = []

    # put network in evaluation mode
    net.eval()

    with torch.no_grad():
        for b, data in enumerate(valid_loader):
            image, label, mask, semi_label, idx = data
            # put data to device
            image, label = image.to(self.device).float(), label.to(self.device)
            mask, semi_label = mask.to(self.device), semi_label.to(self.device)
            idx = idx.to(self.device)

            # mask the input
            image = image * mask

            # compute loss
            rec, embed = net(image)
            loss_rec = criterion_rec(rec, image, mask)
            loss_ad = criterion_ad(embed, self.c, self.R)

            # compute anomaly scores
            # mean over all dimensions except the batch one
            rec_score = torch.mean(loss_rec, dim=tuple(range(1, rec.dim())))

            # squared distance to, and index of, the closest center of each
            # sample. The index is kept in its own variable: it must not
            # overwrite `idx`, which holds the sample indices stored with
            # the scores below.
            dist, sphere_idx = torch.min(
                torch.sum((self.c.unsqueeze(0) - embed.unsqueeze(1))**2, dim=2),
                dim=1)

            if self.soft_boundary:
                # dist - R ** 2 --> negative = normal ; positive = abnormal
                ad_score = dist - torch.stack(
                    [self.R[i] ** 2 for i in sphere_idx], dim=0)
            else:
                ad_score = dist

            # compute overall loss
            mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
            loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
            loss += self.scale_em * self.criterion_weight[1] * loss_ad

            # append scores and label
            idx_label_score_rec += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            rec_score.cpu().data.numpy().tolist()))
            idx_label_score_ad += list(zip(idx.cpu().data.numpy().tolist(),
                                           label.cpu().data.numpy().tolist(),
                                           ad_score.cpu().data.numpy().tolist()))

            epoch_loss += loss.item()
            n_batch += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

    self.valid_time = time.time() - start_time

    # reconstruction score: AUC and threshold returned by get_best_threshold
    self.valid_scores_rec = idx_label_score_rec
    _, label, rec_score = zip(*idx_label_score_rec)
    label, rec_score = np.array(label), np.array(rec_score)
    self.valid_auc_rec = roc_auc_score(label, rec_score)
    self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(
        rec_score, label, metric=f1_score)

    # anomaly detection score: AUC and threshold returned by get_best_threshold
    self.valid_scores_ad = idx_label_score_ad
    _, label, ad_score = zip(*idx_label_score_ad)
    label, ad_score = np.array(label), np.array(ad_score)
    self.valid_auc_ad = roc_auc_score(label, ad_score)
    self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(
        ad_score, label, metric=f1_score)

    # add info to logger
    logger.info(f'>>> Validation Time: {self.valid_time:.3f} [s]')
    logger.info(f'>>> Validation Loss: {epoch_loss / n_batch:.6f}')
    logger.info(f'>>> Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
    logger.info(f'>>> Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}')
    logger.info(f'>>> Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}')
    logger.info(f'>>> Validation DMSVDD AUC: {self.valid_auc_ad:.3%}')
    logger.info(f'>>> Best Threshold for the DMSVDD score maximizing F1-score: {self.scores_threhold_ad:.3f}')
    logger.info(f'>>> Best F1-score on DMSVDD score: {self.valid_f1_ad:.3%}')
    logger.info('>>> Finished validating the Joint DMSVDD and AutoEncoder.\n')
def evaluate(self, net, dataset, mode='test', final=False):
    """
    Evaluate the model with the given dataset.
    ----------
    INPUT
        |---- net (nn.Module) The DMSAD to evaluate. Its forward pass is
        |           unpacked as (reconstruction, embedding).
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is evaluated. Each batch is unpacked as
        |           (input, label, mask, semi_label, idx).
        |---- mode (str) either 'valid' or 'test'. In 'valid' mode the AUCs
        |           are either returned (final=False) or saved together with
        |           the thresholds given by get_best_threshold (final=True).
        |           In 'test' mode the results are saved and the stored
        |           thresholds, when not None, are used for the F1-scores.
        |---- final (bool) whether the call represents the final validation,
        |           in which case the validation results are saved. Only
        |           relevant if mode is 'valid'.
    OUTPUT
        |---- auc (tuple (reconstruction auc, ad auc)) returned only if mode
        |           is 'valid' and final is False. Otherwise None is returned.
    """
    assert mode in (
        'valid', 'test'
    ), f'Mode {mode} is not supported. Should be either "valid" or "test".'
    logger = logging.getLogger()

    def as_list(tensor):
        # bring a tensor back to the host as a plain python list
        return tensor.cpu().data.numpy().tolist()

    # dataloader over the samples to evaluate
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # network on device, providing reconstruction and embedding
    net = net.to(self.device)
    net.return_svdd_embed = True

    # unreduced reconstruction loss (needed for per-sample scores) and AD loss
    loss_fn_rec = MaskedMSELoss(reduction='none')
    loss_fn_ad = DMSADLoss(self.eta, self.eps)

    if final or mode == 'test':
        logger.info(
            f' Start Evaluating the jointly trained DMSAD and AutoEncoder in {mode} mode.'
        )

    running_loss = 0.0
    n_batch = len(loader)
    start_time = time.time()
    scores_rec = []
    scores_ad = []

    net.eval()
    with torch.no_grad():
        for b, batch in enumerate(loader):
            image, label, mask, semi_label, idx = batch
            # move the batch to the device
            image = image.to(self.device).float()
            label = label.to(self.device)
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            idx = idx.to(self.device)

            # keep only the masked part of the input
            image = image * mask

            # forward pass and both losses
            rec, embed = net(image)
            loss_rec = loss_fn_rec(rec, image, mask)
            loss_ad = loss_fn_ad(embed, self.c, semi_label)

            # reconstruction score: mean loss over all non-batch dimensions
            non_batch_dims = tuple(range(1, rec.dim()))
            rec_score = torch.mean(loss_rec, dim=non_batch_dims)

            # distance to, and index of, the closest center of each sample
            center_dist = torch.norm(
                self.c.unsqueeze(0) - embed.unsqueeze(1), p=2, dim=2)
            dist, sphere_idx = torch.min(center_dist, dim=1)

            if self.R is not None:
                # positive if dist > R and negative if dist < R
                radii = torch.stack([self.R[j] for j in sphere_idx], dim=0)
                ad_score = dist - radii
            else:
                # without radius the score is the minimal distance to a center
                ad_score = dist

            # store the scores of the batch
            scores_rec += list(
                zip(as_list(idx), as_list(label), as_list(rec_score)))
            scores_ad += list(
                zip(as_list(idx), as_list(label), as_list(ad_score),
                    as_list(sphere_idx), as_list(embed)))

            # overall weighted loss of the batch
            loss = self.scale_rec * self.criterion_weight[0] * (
                torch.sum(loss_rec) / torch.sum(mask))
            loss += self.scale_ad * self.criterion_weight[1] * loss_ad
            running_loss += loss.item()

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\t Evaluation Batch',
                                 Size=40, erase=True)

    # AUC of both scores
    _, label, rec_score = zip(*scores_rec)
    label = np.array(label)
    rec_score = np.array(rec_score)
    auc_rec = roc_auc_score(label, rec_score)

    _, label, ad_score, _, _ = zip(*scores_ad)
    label = np.array(label)
    ad_score = np.array(ad_score)
    auc_ad = roc_auc_score(label, ad_score)

    # intermediate validation: only report the AUCs
    if mode == 'valid' and not final:
        return auc_rec, auc_ad

    if mode == 'valid':
        # save results
        self.valid_time = time.time() - start_time
        self.valid_scores_rec = scores_rec
        self.valid_auc_rec = auc_rec
        self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(
            rec_score, label, metric=f1_score)
        self.valid_scores_ad = scores_ad
        self.valid_auc_ad = auc_ad
        self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(
            ad_score, label, metric=f1_score)
        # print infos
        logger.info(f'---- Validation Time: {self.valid_time:.3f} [s]')
        logger.info(f'---- Validation Loss: {running_loss / n_batch:.6f}')
        logger.info(
            f'---- Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
        logger.info(
            f'---- Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}'
        )
        logger.info(
            f'---- Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}'
        )
        logger.info(f'---- Validation MSAD AUC: {self.valid_auc_ad:.3%}')
        logger.info(
            f'---- Best Threshold for the MSAD score maximizing F1-score: {self.scores_threhold_ad:.3f}'
        )
        logger.info(
            f'---- Best F1-score on MSAD score: {self.valid_f1_ad:.3%}')
        logger.info(
            '---- Finished validating the Joint DMSAD and AutoEncoder.\n')
    elif mode == 'test':
        # save results
        self.test_time = time.time() - start_time
        self.test_scores_rec = scores_rec
        self.test_auc_rec = auc_rec
        self.test_scores_ad = scores_ad
        self.test_auc_ad = auc_ad
        # print infos
        logger.info(f'---- Test Time: {self.test_time:.3f} [s]')
        logger.info(f'---- Test Loss: {running_loss / n_batch:.6f}')
        logger.info(f'---- Test reconstruction AUC: {self.test_auc_rec:.3%}')
        if self.scores_threhold_rec is not None:
            predicted = np.where(rec_score > self.scores_threhold_rec, 1, 0)
            self.test_f1_rec = f1_score(label, predicted)
            logger.info(
                f'---- Best F1-score on reconstruction score: {self.test_f1_rec:.3%}'
            )
        logger.info(f'---- Test MSAD AUC: {self.test_auc_ad:.3%}')
        if self.scores_threhold_ad is not None:
            predicted = np.where(ad_score > self.scores_threhold_ad, 1, 0)
            self.test_f1_ad = f1_score(label, predicted)
            logger.info(
                f'---- Best F1-score on MSAD score: {self.test_f1_ad:.3%}')
        logger.info('---- Finished testing the Joint DMSAD and AutoEncoder.\n')
def train(self, dataset, net):
    """
    Train the DMSVDD on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. Each batch is unpacked as
        |           (input, _, mask, semi_label, _).
        |---- net (nn.Module) The DMSVDD to train. The network should be an
        |           autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
    OUTPUT
        |---- net (nn.Module) The trained joint DMSVDD. The per-epoch losses
        |           and the total duration are stored in self.train_loss and
        |           self.train_time; self.c (and self.R with a soft
        |           boundary) are pruned after each post-warm-up epoch.
    """
    logger = logging.getLogger()

    # make dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # put net to device
    net = net.to(self.device)
    net.return_svdd_embed = True  # enable the network to provide the SVDD embedding

    # initialize the hypersphere centers
    if self.c is None:
        logger.info('>>> Initializing the hyperspheres centers.')
        self.initialize_centers(train_loader, net)
        logger.info(f'>>> {self.n_sphere_init} centers succesfully initialized.')

    # define the two criterion for Anomaly detection and reconstruction
    criterion_rec = MaskedMSELoss()
    criterion_ad = DMSVDDLoss(self.nu, eps=self.eps,
                              soft_boundary=self.soft_boundary)

    # compute the scale weight so that the rec and svdd losses are scaled and comparable
    self.initialize_loss_scale_weight(train_loader, net, criterion_rec,
                                      criterion_ad)

    # define optimizer
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)

    # define scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Start training
    logger.info('>>> Start Training the Joint DMSVDD and Autoencoder.')
    start_time = time.time()
    epoch_loss_list = []
    n_batch_tot = len(train_loader)
    # set network in train mode
    net.train()
    for epoch in range(self.n_epoch):
        epoch_loss = 0.0
        n_batch = 0
        epoch_start_time = time.time()

        # update network
        for b, data in enumerate(train_loader):
            image, _, mask, semi_label, _ = data
            # put inputs to device
            image = image.to(self.device).float()
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            image.requires_grad = True

            # mask the input (keep only the object)
            image = image * mask

            # zeros the gradient
            optimizer.zero_grad()

            # Update network parameters by backpropagation on the two criterion
            rec, embed = net(image)
            # reconstruction loss
            # ignore reconstruction for known abnormal samples: where
            # semi_label == -1 the reconstruction is replaced by the input
            # itself, so the loss (and its gradient) is zero there
            rec = torch.where(
                semi_label.view(-1, 1, 1, 1).expand(*image.shape) != -1,
                rec, image)
            loss_rec = criterion_rec(rec, image, mask)
            loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
            # SVDD embedding loss
            loss_ad = criterion_ad(embed, self.c, self.R)
            loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad

            loss = loss_rec + loss_ad

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batch += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot,
                                 Name='\t\tWeight-update Batch', Size=20)

        with torch.no_grad():
            # update radius R and prune the centers once the warm-up is over
            if epoch >= self.n_epoch_warm_up:
                n_center = self.c.shape[0]
                # number of samples assigned to each center
                n_k = torch.zeros(n_center, device=self.device)
                # distances of the samples assigned to each center
                dist = [[] for _ in range(n_center)]

                for b, data in enumerate(train_loader):
                    # compute distance and belonging of sample
                    image, _, mask, semi_label, _ = data
                    image = image.to(self.device).float()
                    mask = mask.to(self.device)
                    semi_label = semi_label.to(self.device)
                    # mask the input and drop samples with semi_label == -1
                    image = (image * mask)[semi_label != -1]
                    _, embed = net(image)
                    # get closest centers
                    min_dist, idx = torch.min(
                        torch.norm(self.c.unsqueeze(0) - embed.unsqueeze(1),
                                   p=2, dim=2),
                        dim=1)
                    for i, d in zip(idx, min_dist):
                        n_k[i] += 1
                        dist[i].append(d)

                    if self.print_batch_progress:
                        print_progessbar(b, n_batch_tot,
                                         Name='\t\tRadius-update Batch',
                                         Size=20)

                if self.soft_boundary:
                    # update R with the (1-nu)th quantile of the distances;
                    # centers holding fewer than nu * max(n_k) samples get R = 0
                    quantiles = torch.Tensor([
                        np.quantile(
                            torch.stack(d, dim=0).clone().cpu().numpy(),
                            1 - self.nu) if len(d) > 0 else 0.0
                        for d in dist
                    ]).to(self.device)
                    self.R = torch.where(n_k < self.nu * torch.max(n_k),
                                         torch.Tensor([0.0]).to(self.device),
                                         quantiles)
                    # keep only centers and radius where R > 0
                    kept = self.R > 0.0
                    self.c = self.c[kept]
                    self.R = self.R[kept]
                else:
                    # Keep only the centers that are represented by at least
                    # one sample. The previous mask (n_k == 0) kept exactly
                    # the unrepresented centers, which contradicts the
                    # soft-boundary branch above and leaves no center at all
                    # as soon as every center has a sample.
                    self.c = self.c[n_k > 0]

        # epoch statistic
        epoch_train_time = time.time() - epoch_start_time
        logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train Time: {epoch_train_time:.3f} [s] '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} '
                    f'| N spheres: {self.c.shape[0]:03} |')

        # append the epoch loss to results list
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # update the learning rate if the milestone is reached
        scheduler.step()
        if epoch + 1 in self.lr_milestone:
            # Read the rate actually applied from the optimizer:
            # scheduler.get_lr() is not meant to be called from user code
            # and, in recent PyTorch releases, can report a value decayed
            # one extra time when called right after a milestone.
            current_lr = optimizer.param_groups[0]['lr']
            logger.info(f'>>> LR Scheduler : new learning rate {current_lr:g}')

    # End training
    self.train_loss = epoch_loss_list
    self.train_time = time.time() - start_time
    logger.info(f'>>> Training of Joint DMSVDD and AutoEncoder Time: {self.train_time:.3f} [s]')
    logger.info('>>> Finished Joint DMSVDD and AutoEncoder Training.\n')

    return net
def pretrain(self, net, dataset):
    """
    Pretrain the AE for the joint DMSAD network on the provided dataset.
    ----------
    INPUT
        |---- net (nn.Module) The DMSAD to pretrain. Its forward pass is
        |           unpacked as a pair of which only the first element (the
        |           reconstruction) is used.
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is pretrained. It must return a tuple (image,
        |           label, mask, semi-supervized labels, idx).
    OUTPUT
        |---- net (nn.Module) The pretrained joint DMSAD. The per-epoch
        |           losses and the total duration are stored in
        |           self.pretrain_loss and self.pretrain_time.
    """
    logger = logging.getLogger()

    # dataloader over the training samples
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # network on device, providing only the reconstruction
    net = net.to(self.device)
    net.return_svdd_embed = False

    # reconstruction loss function and optimizer
    loss_fn_rec = MaskedMSELoss()
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)

    logger.info(' Start Pretraining the Autoencoder.')
    start_time = time.time()
    loss_history = []
    n_batch = len(train_loader)

    for epoch in range(self.n_epoch_pretrain):
        net.train()
        running_loss = 0.0
        epoch_start_time = time.time()

        for b, batch in enumerate(train_loader):
            # get batch data on device
            image, _, mask, semi_label, _ = batch
            image = image.to(self.device).float().requires_grad_(True)
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)

            # mask the input and drop the samples with semi_label == -1
            is_kept = semi_label != -1
            image = (image * mask)[is_kept]
            mask = mask[is_kept]

            # forward + backward + optimizer step
            optimizer.zero_grad()
            rec, _ = net(image)
            loss = loss_fn_rec(rec, image, mask)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tBatch', Size=40,
                                 erase=True)

        # print epoch statistics
        logger.info(
            f'----| Epoch {epoch + 1:03}/{self.n_epoch_pretrain:03} '
            f'| Pretrain Time {time.time() - epoch_start_time:.3f} [s] '
            f'| Pretrain Loss {running_loss / n_batch:.6f} |')
        # store loss
        loss_history.append([epoch + 1, running_loss / n_batch])

    # End training
    self.pretrain_loss = loss_history
    self.pretrain_time = time.time() - start_time
    logger.info(
        f'---- Finished Pretraining the AutoEncoder in {self.pretrain_time:.3f} [s].'
    )

    return net
def train(self, net, dataset, valid_dataset=None):
    """
    Train the DMSAD on the provided dataset.
    ----------
    INPUT
        |---- net (nn.Module) The DMSAD to train. The network should be an
        |           autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. It must return an image, a mask and
        |           semi-supervized labels.
        |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
        |           to validate the model at each epoch. No validation is
        |           performed if not provided.
    OUTPUT
        |---- net (nn.Module) The trained joint DMSAD.
    """
    logger = logging.getLogger()

    # make dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)
    # put net on device
    net = net.to(self.device)
    # set the network to provide both the reconstruction and the embedding
    net.return_svdd_embed = True

    # Initialize the hyper-sphere centers (only if not already defined)
    if self.c is None:
        logger.info(' Initializing the hypersheres centers.')
        self.initialize_centers(train_loader, net)
        logger.info(
            f' {self.c.shape[0]} centers successfully initialized.')

    # define the reconstruction and anomaly-detection loss functions
    loss_fn_rec = MaskedMSELoss()
    loss_fn_ad = DMSADLoss(self.eta, eps=self.eps)

    # Compute the scaling factors for the reconstruction and DMSAD losses
    logger.info(' Initializing the loss scale factors.')
    self.initialize_loss_scale_weight(train_loader, net, loss_fn_rec,
                                      loss_fn_ad)
    logger.info(
        f' reconstruction loss scale factor initialized to {self.scale_rec:.6f}'
    )
    logger.info(
        f' MSAD embdeding loss scale factor initialized to {self.scale_ad:.6f}'
    )

    # define the optimizer
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)
    # define the learning rate scheduler : 90% reduction at each step
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Start training
    logger.info(' Start Training Jointly the DMSAD and the Autoencoder.')
    start_time = time.time()
    epoch_loss_list = []
    n_batch = len(train_loader)

    for epoch in range(self.n_epoch):
        net.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()
        # number of non-anomalous samples assigned to each center this epoch
        n_k = torch.zeros(self.c.shape[0], device=self.device)

        for b, data in enumerate(train_loader):
            # get batch data
            x, _, mask, semi_label, _ = data
            x = x.to(self.device).float().requires_grad_(True)
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            # mask input
            x = x * mask

            # Update the network by backpropagation using the two losses.
            optimizer.zero_grad()
            rec, embed = net(x)

            # Reconstruction loss only on normal samples: for samples
            # labelled -1 (abnormal) the reconstruction is replaced by the
            # input itself so that their loss is zero.
            is_normal = (semi_label != -1).view(-1, 1, 1, 1).expand(*x.shape)
            rec = torch.where(is_normal, rec, x)
            loss_rec = (self.scale_rec * self.criterion_weight[0]
                        * loss_fn_rec(rec, x, mask))
            # DMSAD loss
            loss_ad = (self.scale_ad * self.criterion_weight[1]
                       * loss_fn_ad(embed, self.c, semi_label))
            # total loss
            loss = loss_rec + loss_ad
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            # get the closest sphere of each sample and count the number of
            # normal samples per sphere (vectorised with bincount)
            dist = torch.norm(
                self.c.unsqueeze(0) - embed.detach().unsqueeze(1),
                p=2, dim=2)
            closest = torch.argmin(dist, dim=1)
            n_k += torch.bincount(closest[semi_label != -1],
                                  minlength=n_k.shape[0]).to(n_k.dtype)

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch',
                                 Size=40, erase=True)

        # remove centers with less than gamma fraction of the number of
        # samples of the most populated hypersphere
        self.c = self.c[n_k >= self.gamma * torch.max(n_k)]

        # intermediate validation of the model if required
        valid_auc = ''
        if valid_dataset:
            auc_rec, auc_ad = self.evaluate(net, valid_dataset,
                                            mode='valid', final=False)
            valid_auc = f' Rec AUC {auc_rec:.3%} | MSAD AUC {auc_ad:.3%} |'

        # print epoch statistics
        logger.info(
            f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
            f'| Train Time {time.time() - epoch_start_time:.3f} [s] '
            f'| Train Loss {epoch_loss / n_batch:.6f} '
            f'| N sphere {self.c.shape[0]:03} |' + valid_auc)
        # store loss
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # update learning rate if milestone is reached
        scheduler.step()
        if epoch + 1 in self.lr_milestone:
            # Read the rate from the optimizer: scheduler.get_lr() called
            # outside of step() reports a wrong value at milestones.
            new_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f'---- LR Scheduler : new learning rate {new_lr:g}'
            )

        # re-initialize the loss scale factors after a few epochs, when the
        # centers are more or less defined
        if epoch + 1 == self.reset_scaling_epoch:
            with torch.no_grad():
                logger.info('---- Reinitializing the loss scale factors.')
                self.initialize_loss_scale_weight(train_loader, net,
                                                  loss_fn_rec, loss_fn_ad)
                logger.info(
                    f'---- reconstruction loss scale factor reinitialized to {self.scale_rec:.6f}'
                )
                logger.info(
                    f'---- MSAD embdeding loss scale factor reinitialized to {self.scale_ad:.6f}'
                )

    # Set the radius of each sphere as 1-gamma quantile of normal samples distances
    logger.info(
        f'---- Setting the hyperspheres radii as the {1-self.gamma:.1%} quantiles of normal sample distances.'
    )
    self.set_radius(train_loader, net)
    logger.info(f'---- {self.R.shape[0]} radii successufully defined.')

    # End Training
    self.train_loss = epoch_loss_list
    self.train_time = time.time() - start_time
    logger.info(
        f'---- Finished jointly training the DMSAD and the Autoencoder in {self.train_time:.3f} [s].'
    )

    return net
def train(self, dataset, net, valid_dataset=None):
    """
    Train the ARAE network on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. It must return an image, a mask and
        |           semi-supervised labels.
        |---- net (nn.Module) The ARAE to train. The network should be an
        |           autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
        |---- valid_dataset (torch.utils.data.Dataset) optional dataset on
        |           which self.validate is called after each epoch.
    OUTPUT
        |---- net (nn.Module) The trained ARAE.
    """
    logger = logging.getLogger()

    # make dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)
    # put net to device
    net = net.to(self.device)

    # define the criterions (reconstruction and latent consistency)
    criterion_rec = MaskedMSELoss()
    criterion_lat = nn.MSELoss()
    # define optimizer
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)
    # define scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Start training
    logger.info('>>> Start Training the ARAE.')
    start_time = time.time()
    epoch_loss_list = []
    n_batch_tot = len(train_loader)
    # set network in train mode
    net.train()

    for epoch in range(self.n_epoch):
        epoch_loss = 0.0
        n_batch = 0
        epoch_start_time = time.time()

        for b, data in enumerate(train_loader):
            x, _, mask, _, _ = data
            # NOTE(review): requires_grad is presumably needed by the
            # adversarial search on the input -- confirm against
            # adversarial_search / FGSM.
            x = x.to(self.device).float().requires_grad_(True)
            mask = mask.to(self.device)
            # mask input
            x = x * mask

            # build the adversarial counterpart of the batch
            if self.use_PGD:
                adv_x = self.adversarial_search(x, net)
            else:
                adv_x = self.FGSM(x, net)

            # pass the normal samples through the network with the
            # encoding_only flag set, then the adversarial ones without it
            net.encoding_only = True
            _, lat = net(x)
            net.encoding_only = False
            rec_adv, lat_adv = net(adv_x)

            # compute the loss
            loss_rec = criterion_rec(adv_x, rec_adv, mask)
            loss_lat = criterion_lat(lat, lat_adv)
            loss = loss_rec + self.gamma * loss_lat

            # gradients are zeroed only here so that anything accumulated
            # during the adversarial search is discarded before backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batch += 1

            if self.print_batch_progress:
                print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        # intermediate validation if required
        valid_auc = ''
        if valid_dataset:
            auc = self.validate(valid_dataset, net, final=False)
            # switch back to train mode after validation
            net.train()
            valid_auc = f' Rec AUC {auc:.3%} |'

        # epoch statistic
        epoch_train_time = time.time() - epoch_start_time
        logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Train Time: {epoch_train_time:.3f} [s] '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} |' + valid_auc)
        # append the epoch loss to results list
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # update the learning rate if the milestone is reached
        scheduler.step()
        if epoch + 1 in self.lr_milestone:
            # Read the rate from the optimizer: scheduler.get_lr() called
            # outside of step() reports a wrong value at milestones.
            new_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f'>>> LR Scheduler : new learning rate {new_lr:g}'
            )

    # End training
    self.train_loss = epoch_loss_list
    self.train_time = time.time() - start_time
    logger.info(f'>>> Training Time of ARAE: {self.train_time:.3f} [s]')
    logger.info('>>> Finished ARAE Training.\n')

    return net
def evaluate(self, net, dataset, mode='test', final=False):
    """
    Evaluate the model with the given dataset.
    ----------
    INPUT
        |---- net (nn.Module) The DSAD to validate. The network should be an
        |           autoencoder for which the forward pass returns both the
        |           reconstruction and the embedding of the input.
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is validated. It must return an image, a mask and
        |           semi-supervized labels.
        |---- mode (str) either 'valid' or 'test'. Define the evaluation mode.
        |           In 'valid' the evaluation can return the reconstruction
        |           and SAD AUCs and compute the best threshold to maximize
        |           the F1-scores. In test mode the validation threshold is
        |           used to compute the F1-score.
        |---- final (bool) whether the call represents the final validation,
        |           in which case the validation results are saved. Only
        |           relevant if mode is 'valid'.
    OUTPUT
        |---- auc (tuple (reconstruction auc, ad auc)) the validation AUCs
        |           for both scores. Returned only if mode is 'valid' and
        |           final is False. Otherwise None is returned and the
        |           results are stored as attributes.
    """
    assert mode in ['valid', 'test'], f'Mode {mode} is not supported. Should be either "valid" or "test".'
    logger = logging.getLogger()

    # make test dataloader using image and mask
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)
    # put net to device
    net = net.to(self.device)
    net.return_svdd_embed = True

    # define the two criterion for Anomaly detection and reconstruction
    criterion_rec = MaskedMSELoss(reduction='none')
    criterion_ad = self.SADLoss(self.space_repr, self.eta, eps=self.eps)

    # Testing
    if final or mode == 'test':
        logger.info(f' Start Evaluating the jointly trained DSAD and AutoEncoder in {mode} mode.')
    epoch_loss = 0.0
    n_batch = len(loader)
    start_time = time.time()
    idx_label_score_rec, idx_label_score_ad = [], []

    net.eval()
    with torch.no_grad():
        for b, data in enumerate(loader):
            x, label, mask, semi_label, idx = data
            # put data to device
            x, label = x.to(self.device).float(), label.to(self.device)
            mask, semi_label = mask.to(self.device), semi_label.to(self.device)
            idx = idx.to(self.device)
            # mask the input
            x = x * mask

            # compute loss
            rec, embed = net(x)
            loss_rec = criterion_rec(rec, x, mask)
            loss_ad = criterion_ad(embed, semi_label)

            # reconstruction score : mean over all dimensions but the first
            rec_score = torch.mean(loss_rec, dim=tuple(range(1, rec.dim())))
            # anomaly score : squared distance (large distances highlight
            # anomalies), either to the projection on the subspace or to
            # self.space_repr directly
            if self.use_subspace:
                proj = torch.matmul(self.space_repr,
                                    embed.transpose(0, 1)).transpose(0, 1)
                ad_score = torch.sum((embed - proj) ** 2, dim=1)
            else:
                ad_score = torch.sum((embed - self.space_repr) ** 2, dim=1)

            # compute overall loss
            # NOTE(review): the AD scale factor is read from self.scale_em
            # here -- confirm this matches the attribute set during training.
            mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
            loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
            loss += self.scale_em * self.criterion_weight[1] * loss_ad

            # append scores and label
            idx_list = idx.cpu().data.numpy().tolist()
            label_list = label.cpu().data.numpy().tolist()
            idx_label_score_rec += list(zip(
                idx_list, label_list, rec_score.cpu().data.numpy().tolist()))
            idx_label_score_ad += list(zip(
                idx_list, label_list, ad_score.cpu().data.numpy().tolist()))

            epoch_loss += loss.item()

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tBatch', Size=40, erase=True)

    # compute AUCs
    _, label, rec_score = zip(*idx_label_score_rec)
    label, rec_score = np.array(label), np.array(rec_score)
    auc_rec = roc_auc_score(label, rec_score)

    _, label, ad_score = zip(*idx_label_score_ad)
    label, ad_score = np.array(label), np.array(ad_score)
    auc_ad = roc_auc_score(label, ad_score)

    if mode == 'valid':
        if not final:
            return auc_rec, auc_ad

        # final validation : save the results and the best thresholds
        self.valid_time = time.time() - start_time
        self.valid_scores_rec = idx_label_score_rec
        self.valid_auc_rec = auc_rec
        self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(rec_score, label, metric=f1_score)
        self.valid_scores_ad = idx_label_score_ad
        self.valid_auc_ad = auc_ad
        self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(ad_score, label, metric=f1_score)
        # add info to logger
        logger.info(f'---- Validation Time: {self.valid_time:.3f} [s]')
        logger.info(f'---- Validation Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'---- Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
        logger.info(f'---- Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}')
        logger.info(f'---- Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}')
        logger.info(f'---- Validation SAD AUC: {self.valid_auc_ad:.3%}')
        logger.info(f'---- Best Threshold for the MSAD score maximizing F1-score: {self.scores_threhold_ad:.3f}')
        logger.info(f'---- Best F1-score on SAD score: {self.valid_f1_ad:.3%}')
        logger.info('---- Finished validating the Joint DSAD and AutoEncoder.\n')
    elif mode == 'test':
        # save results
        self.test_time = time.time() - start_time
        self.test_scores_rec = idx_label_score_rec
        self.test_auc_rec = auc_rec
        self.test_scores_ad = idx_label_score_ad
        self.test_auc_ad = auc_ad

        # print infos
        logger.info(f'---- Test Time: {self.test_time:.3f} [s]')
        logger.info(f'---- Test Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'---- Test reconstruction AUC: {self.test_auc_rec:.3%}')
        # F1-scores are computed only if a threshold is available
        if self.scores_threhold_rec is not None:
            self.test_f1_rec = f1_score(label, np.where(rec_score > self.scores_threhold_rec, 1, 0))
            logger.info(f'---- Best F1-score on reconstruction score: {self.test_f1_rec:.3%}')
        logger.info(f'---- Test SAD AUC: {self.test_auc_ad:.3%}')
        if self.scores_threhold_ad is not None:
            self.test_f1_ad = f1_score(label, np.where(ad_score > self.scores_threhold_ad, 1, 0))
            logger.info(f'---- Best F1-score on SAD score: {self.test_f1_ad:.3%}')
        logger.info('---- Finished testing the Joint DSAD and AutoEncoder.\n')
def test(self, dataset, net):
    """
    Run a trained ARAE over a dataset and store the resulting metrics.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the data to test on. Each
        |           item is unpacked as (image, label, mask, _, index).
        |---- net (nn.Module) The ARAE to test. Its forward pass must return
        |           a pair whose first element is the reconstruction.
    OUTPUT
        |---- None (results are stored in self.test_time, self.test_scores,
        |           self.test_auc and self.test_f1).
    """
    log = logging.getLogger()

    # data loader over the test set
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)
    # network on device
    net = net.to(self.device)
    # element-wise masked MSE (no reduction)
    pixel_loss_fn = MaskedMSELoss(reduction='none')

    log.info('>>> Start Testing of the ARAE.')
    cumulated_loss = 0.0
    n_seen = 0
    total_batches = len(loader)
    t0 = time.time()
    results = []

    net.eval()
    with torch.no_grad():
        for n_seen, batch in enumerate(loader, start=1):
            image, target, mask, _, index = batch
            # move everything to the device
            image = image.to(self.device).float()
            target = target.to(self.device).float()
            mask = mask.to(self.device)
            index = index.to(self.device)

            # masked input through the network
            image = image * mask
            reconstruction, _ = net(image)
            pixel_loss = pixel_loss_fn(reconstruction, image, mask)

            # per-sample score : mean over every dimension but the first
            sample_dims = tuple(range(1, reconstruction.dim()))
            sample_score = torch.mean(pixel_loss, dim=sample_dims)
            results.extend(zip(index.cpu().data.numpy().tolist(),
                               target.cpu().data.numpy().tolist(),
                               sample_score.cpu().data.numpy().tolist()))

            # batch loss : summed loss normalised by the mask sum
            batch_loss = torch.sum(pixel_loss) / torch.sum(mask)
            cumulated_loss += batch_loss.item()

            if self.print_batch_progress:
                print_progessbar(n_seen - 1, total_batches,
                                 Name='\t\tBatch', Size=20)

    # store timing and raw (index, label, score) triplets
    self.test_time = time.time() - t0
    self.test_scores = results

    # AUC and F1-score from the collected scores
    _, labels, scores = zip(*results)
    labels = np.array(labels)
    scores = np.array(scores)
    self.test_auc = roc_auc_score(labels, scores)
    # NOTE(review): other trainers in this file spell the threshold
    # attribute `scores_threhold`; confirm which name this class defines.
    predictions = np.where(scores > self.scores_threshold, 1, 0)
    self.test_f1 = f1_score(labels, predictions)

    # add info to logger
    log.info(f'>>> Testing Time: {self.test_time:.3f} [s]')
    log.info(f'>>> Test Loss: {cumulated_loss / n_seen:.6f}')
    log.info(f'>>> Test AUC: {self.test_auc:.3%}')
    log.info(f'>>> Test F1-score: {self.test_f1:.3%}')
    log.info('>>> Finished testing the ARAE.\n')
def train(self, dataset, ae_net):
    """
    Train the autoencoder network on the provided dataset.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. It must return an image and a mask
        |           of where the loss is to be computed.
        |---- ae_net (nn.Module) The autoencoder to train.
    OUTPUT
        |---- ae_net (nn.Module) The trained autoencoder.
    """
    logger = logging.getLogger()

    # make train dataloader using image and mask
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)
    n_batch = len(train_loader)

    # masked MSE loss with its default reduction
    criterion = MaskedMSELoss()

    # set to device
    ae_net = ae_net.to(self.device)
    criterion = criterion.to(self.device)

    # set optimizer
    optimizer = optim.Adam(ae_net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)
    # set the learning rate scheduler (multiple phase learning)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Training
    logger.info('>>> Start Training the AutoEncoder.')
    start_time = time.time()
    epoch_loss_list = []
    # set the network in train mode
    ae_net.train()

    for epoch in range(self.n_epoch):
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for b, data in enumerate(train_loader):
            x, _, mask, _, _ = data
            # put inputs to device
            x, mask = x.to(self.device).float(), mask.to(self.device)

            # zero the network gradients
            optimizer.zero_grad()

            # update the parameters by backpropagation of the masked loss
            rec = ae_net(x)
            loss = criterion(rec, x, mask)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tBatch', Size=20)

        # epoch statistic
        epoch_train_time = time.time() - epoch_start_time
        logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | Train Time: {epoch_train_time:.3f} [s] '
                    f'| Train Loss: {epoch_loss / n_batch:.6f} |')
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # apply the scheduler step
        scheduler.step()
        # After this step the scheduler has completed `epoch + 1` epochs, so
        # the rate changes exactly when `epoch + 1` is a milestone.
        if epoch + 1 in self.lr_milestone:
            # Read the rate from the optimizer: scheduler.get_lr() called
            # outside of step() reports a wrong value at milestones.
            new_lr = optimizer.param_groups[0]['lr']
            logger.info('>>> LR Scheduler : new learning rate %g' % float(new_lr))

    # End training
    self.train_loss = epoch_loss_list
    self.train_time = time.time() - start_time
    logger.info(f'>>> Training of AutoEncoder Time: {self.train_time:.3f} [s]')
    logger.info('>>> Finished AutoEncoder Training.\n')

    return ae_net
def test(self, dataset, ae_net):
    """
    Run a trained autoencoder over a dataset and store the resulting metrics.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the data to test on. Each
        |           item is unpacked as (image, label, mask, _, index).
        |---- ae_net (nn.Module) The autoencoder to test. Its forward pass
        |           must return the reconstruction only.
    OUTPUT
        |---- None (results are stored in self.test_time, self.test_scores,
        |           self.test_auc and self.test_f1).
    """
    log = logging.getLogger()

    # data loader over the test set
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_jobs_dataloader)

    # element-wise masked MSE (no reduction), network and loss on device
    pixel_loss_fn = MaskedMSELoss(reduction='none')
    ae_net = ae_net.to(self.device)
    pixel_loss_fn = pixel_loss_fn.to(self.device)

    log.info('>>> Start Testing the AutoEncoder.')
    cumulated_loss = 0.0
    n_seen = 0
    t0 = time.time()
    results = []

    # evaluation mode, no gradient tracking
    ae_net.eval()
    with torch.no_grad():
        for n_seen, batch in enumerate(loader, start=1):
            image, target, mask, _, index = batch
            # move everything to the device
            image = image.to(self.device).float()
            target = target.to(self.device)
            mask = mask.to(self.device)
            index = index.to(self.device)

            reconstruction = ae_net(image)
            pixel_loss = pixel_loss_fn(reconstruction, image, mask)

            # per-sample score : mean over every dimension but the first
            sample_dims = tuple(range(1, reconstruction.dim()))
            sample_score = torch.mean(pixel_loss, dim=sample_dims)
            results.extend(zip(index.cpu().data.numpy().tolist(),
                               target.cpu().data.numpy().tolist(),
                               sample_score.cpu().data.numpy().tolist()))

            # batch loss : summed loss normalised by the mask sum
            batch_loss = torch.sum(pixel_loss) / torch.sum(mask)
            cumulated_loss += batch_loss.item()

            if self.print_batch_progress:
                print_progessbar(n_seen - 1, len(loader),
                                 Name='\t\tBatch', Size=20)

    # store timing and raw (index, label, score) triplets
    self.test_time = time.time() - t0
    self.test_scores = results

    # AUC and F1-score : the per-sample reconstruction loss is the score
    _, labels, scores = zip(*results)
    labels = np.array(labels)
    scores = np.array(scores)
    self.test_auc = roc_auc_score(labels, scores)
    predictions = np.where(scores > self.scores_threhold, 1, 0)
    self.test_f1 = f1_score(labels, predictions)

    # add info to logger
    log.info(f'>>> Test Time: {self.test_time:.3f} [s]')
    log.info(f'>>> Test Loss: {cumulated_loss / n_seen:.6f}')
    log.info(f'>>> Test AUC: {self.test_auc:.3%}')
    log.info(f'>>> Test F1-score: {self.test_f1:.3%}')
    log.info('>>> Finished Testing the AutoEncoder.\n')
def train(self, net, dataset, valid_dataset=None):
    """
    Train the autoencoder network on the provided dataset.
    ----------
    INPUT
        |---- net (nn.Module) The autoencoder to train. It must return two
        |           embedding (after the convolution and after the MLP) as
        |           well as the reconstruction
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is trained. It must return an image , the label,
        |           a mask, semi-supervised label and the index.
        |---- valid_dataset (torch.utils.data.Dataset) the optional dataset
        |           on which to validate the model at each epoch.
    OUTPUT
        |---- net (nn.Module) The trained autoencoder.
    """
    logger = logging.getLogger()

    # make train dataloader using image and mask
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_job_dataloader)
    # define loss_fn
    loss_fn = MaskedMSELoss()
    # set network on device
    net = net.to(self.device)
    # define optimizer
    optimizer = optim.Adam(net.parameters(), lr=self.lr,
                           weight_decay=self.weight_decay)
    # set the learning rate scheduler (multiple phase learning)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.lr_milestone, gamma=0.1)

    # Training
    logger.info('Start Training AE.')
    start_time = time.time()
    epoch_loss_list = []
    n_batch = len(train_loader)

    for epoch in range(self.n_epoch):
        epoch_loss = 0.0
        epoch_start_time = time.time()
        # train mode is set every epoch (the optional validation below may
        # change the network mode)
        net.train()

        for b, data in enumerate(train_loader):
            # all samples are used, whatever their semi-supervised label
            x, _, mask, _, _ = data
            # put inputs to device
            x = x.to(self.device).float().requires_grad_(True)
            mask = mask.to(self.device)
            # mask input
            x = x * mask

            # zero the network gradients
            optimizer.zero_grad()
            # update the parameters by backpropagation of the masked loss
            _, _, rec = net(x)
            loss = loss_fn(rec, x, mask)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tTrain Batch',
                                 Size=40, erase=True)

        # intermediate validation if required
        valid_auc = ''
        if valid_dataset:
            auc = self.evaluate(net, valid_dataset, save_tSNE=False,
                                return_auc=True, print_to_logger=False)
            valid_auc = f' Valid AUC {auc:.6f} |'

        # display epoch statistics
        logger.info(f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
                    f'| Time {time.time() - epoch_start_time:.3f} [s]'
                    f'| Loss {epoch_loss / n_batch:.6f} |' + valid_auc)
        # store loss
        epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # update learning rate if milestone is reached
        scheduler.step()
        if epoch + 1 in self.lr_milestone:
            # Read the rate from the optimizer: scheduler.get_lr() called
            # outside of step() reports a wrong value at milestones.
            new_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f'---- LR Scheduler : new learning rate {new_lr:g}'
            )

    # Save results
    self.train_time = time.time() - start_time
    self.train_loss = epoch_loss_list
    logger.info(f'---- Finished Training AE in {self.train_time:.3f} [s].')

    return net
def evaluate(self, net, dataset, print_to_logger=True, return_auc=False, save_tSNE=True):
    """
    Evaluate the network on the provided dataset.
    ----------
    INPUT
        |---- net (nn.Module) The autoencoder to evaluate. It must return two
        |           embedding (after the convolution and after the MLP) as
        |           well as the reconstruction
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is validated. It must return an image and a mask
        |           of where the loss is to be computed.
        |---- print_to_logger (bool) whether to print info in logger.
        |---- return_auc (bool) whether to return the computed AUC.
        |---- save_tSNE (bool) whether to save the intermediate representation
        |           as a 2D vector using tSNE (stored in self.eval_repr).
    OUTPUT
        |---- auc (float) the AUC of the reconstruction scores. Returned only
        |           if return_auc is True, otherwise None is returned.
    """
    logger = logging.getLogger()

    # make dataloader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.n_job_dataloader)
    # put net on device
    net = net.to(self.device)
    # define loss function (element-wise, no reduction)
    loss_fn = MaskedMSELoss(reduction='none')

    if print_to_logger:
        logger.info("Start Evaluating AE.")

    idx_label_scores = []
    n_batch = len(loader)

    net.eval()
    with torch.no_grad():
        for b, data in enumerate(loader):
            x, label, mask, _, idx = data
            # put inputs to device
            x = x.to(self.device).float()
            label = label.to(self.device)
            mask = mask.to(self.device)
            idx = idx.to(self.device)
            # mask input
            x = x * mask

            h, z, rec = net(x)

            # compute score as the mean loss per sample
            rec_loss = loss_fn(rec, x, mask)
            score = torch.mean(rec_loss, dim=tuple(range(1, rec.dim())))

            # append scores : idx label score h z
            idx_label_scores += list(
                zip(idx.cpu().data.numpy().tolist(),
                    label.cpu().data.numpy().tolist(),
                    score.cpu().data.numpy().tolist(),
                    h.cpu().data.numpy().tolist(),
                    z.cpu().data.numpy().tolist()))

            if self.print_batch_progress:
                print_progessbar(b, n_batch, Name='\t\tEvaluation Batch',
                                 Size=40, erase=True)

    if save_tSNE:
        if print_to_logger:
            logger.info("Computing the t-SNE representation.")
        # Apply t-SNE transform on both embeddings
        index, label, scores, h, z = zip(*idx_label_scores)
        h, z = np.array(h), np.array(z)
        h = TSNE(n_components=2).fit_transform(h)
        z = TSNE(n_components=2).fit_transform(z)
        self.eval_repr = list(
            zip(index, label, scores, h.tolist(), z.tolist()))

        if print_to_logger:
            logger.info("Succesfully computed the t-SNE representation ")

    if return_auc:
        # The list holds one tuple per sample, so it has to be transposed
        # with zip(*...) before the label and score columns can be read.
        _, label, scores, _, _ = zip(*idx_label_scores)
        auc = roc_auc_score(np.array(label), np.array(scores))
        return auc