# Imports assumed from the surrounding repo and its standard dependencies.
# SVI, ImportanceWeightedSampler, binary_cross_entropy, BaseADDataset, and
# BaseNet come from the repo's own modules; their import paths may differ.
import logging
import time

import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import roc_auc_score


def train(self, dataset: BaseADDataset, net: BaseNet):
    logger = logging.getLogger()

    # Get train data loader
    train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

    # Set device
    net = net.to(self.device)

    # Use importance-weighted sampler (Burda et al., 2015) to get a better estimate of the log-likelihood.
    sampler = ImportanceWeightedSampler(mc=1, iw=1)
    elbo = SVI(net, likelihood=binary_cross_entropy, sampler=sampler)

    # Set optimizer (Adam optimizer for now)
    optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    # Set learning rate scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

    # Training
    logger.info('Starting training...')
    start_time = time.time()
    net.train()
    for epoch in range(self.n_epochs):

        epoch_loss = 0.0
        n_batches = 0
        epoch_start_time = time.time()
        for data in train_loader:
            inputs, labels, semi_targets, _ = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            semi_targets = semi_targets.to(self.device)

            # Split the batch into labeled and unlabeled samples and one-hot encode the labels
            inputs = inputs.view(inputs.size(0), -1)
            x = inputs[semi_targets != 0]
            u = inputs[semi_targets == 0]
            y = labels[semi_targets != 0]

            # Skip the labeled terms if the batch holds at most one labeled sample
            has_labeled = y.nelement() > 1
            if has_labeled:
                y_onehot = torch.zeros(y.size(0), 2, device=self.device)  # two labels: 0: normal, 1: outlier
                y_onehot.scatter_(1, y.view(-1, 1), 1)

            # Zero the network parameter gradients
            optimizer.zero_grad()

            # Update network parameters via backpropagation: forward + backward + optimize
            L = -elbo(x, y_onehot) if has_labeled else torch.tensor(0.0, device=self.device)
            U = -elbo(u)

            # Auxiliary classification (cross-entropy) loss on q(y|x)
            if has_labeled:
                logits = net.classify(x)
                eps = 1e-8
                classification_loss = -torch.sum(y_onehot * torch.log(logits + eps), dim=1).mean()
            else:
                classification_loss = torch.tensor(0.0, device=self.device)

            # Overall loss J_alpha = L + alpha * classification loss + U
            loss = L + self.alpha * classification_loss + U
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        # Step the scheduler once per epoch, after the optimizer updates
        scheduler.step()
        if epoch + 1 in self.lr_milestones:
            logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_last_lr()[0]))

        # Log epoch statistics
        epoch_train_time = time.time() - epoch_start_time
        logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                    f'| Train Loss: {epoch_loss / n_batches:.6f} |')

    self.train_time = time.time() - start_time
    logger.info('Training Time: {:.3f}s'.format(self.train_time))
    logger.info('Finished training.')

    return net
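
# The objective optimized above is J_alpha from Kingma et al. (2014):
#     J_alpha = -ELBO(x, y) + alpha * CE(q(y|x), y) - ELBO(u),
# i.e. the labeled ELBO term plus a weighted auxiliary classification loss plus
# the unlabeled ELBO term. Below is a minimal self-contained sketch of that
# combination; `elbo_fn` and `classify_fn` are hypothetical stand-ins for the
# SVI object and net.classify used above, not part of this repo's API.

def j_alpha_sketch(elbo_fn, classify_fn, x, y, u, alpha, eps=1e-8):
    """Combine labeled ELBO, classification loss, and unlabeled ELBO into J_alpha."""
    y_onehot = torch.zeros(y.size(0), 2, device=y.device)  # 0: normal, 1: outlier
    y_onehot.scatter_(1, y.view(-1, 1), 1)
    L = -elbo_fn(x, y_onehot)                               # labeled -ELBO term
    U = -elbo_fn(u)                                         # unlabeled -ELBO term
    ce = -torch.sum(y_onehot * torch.log(classify_fn(x) + eps), dim=1).mean()
    return L + alpha * ce + U
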
def test(self, dataset: BaseADDataset, net: BaseNet):
    logger = logging.getLogger()

    # Get test data loader
    _, test_loader = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

    # Set device
    net = net.to(self.device)

    # Use importance-weighted sampler (Burda et al., 2015) to get a better estimate of the log-likelihood.
    sampler = ImportanceWeightedSampler(mc=1, iw=1)
    elbo = SVI(net, likelihood=binary_cross_entropy, sampler=sampler)

    # Testing
    logger.info('Starting testing...')
    epoch_loss = 0.0
    n_batches = 0
    start_time = time.time()
    idx_label_score = []
    net.eval()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels, _, idx = data
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            idx = idx.to(self.device)

            # All test data is considered unlabeled
            inputs = inputs.view(inputs.size(0), -1)
            u = inputs
            y = labels
            y_onehot = torch.zeros(y.size(0), 2, device=self.device)  # two labels: 0: normal, 1: outlier
            y_onehot.scatter_(1, y.view(-1, 1), 1)

            # Compute loss
            L = -elbo(u, y_onehot)
            U = -elbo(u)
            logits = net.classify(u)
            eps = 1e-8
            classification_loss = -torch.sum(y_onehot * torch.log(logits + eps), dim=1).mean()
            loss = L + self.alpha * classification_loss + U  # J_alpha

            # Compute anomaly scores: the classifier's confidence for the anomalous class
            scores = logits[:, 1]

            # Save triples of (idx, label, score) in a list
            idx_label_score += list(zip(idx.cpu().data.numpy().tolist(),
                                        labels.cpu().data.numpy().tolist(),
                                        scores.cpu().data.numpy().tolist()))

            epoch_loss += loss.item()
            n_batches += 1

    self.test_time = time.time() - start_time
    self.test_scores = idx_label_score

    # Compute AUC
    _, labels, scores = zip(*idx_label_score)
    labels = np.array(labels)
    scores = np.array(scores)
    self.test_auc = roc_auc_score(labels, scores)

    # Log results
    logger.info('Test Loss: {:.6f}'.format(epoch_loss / n_batches))
    logger.info('Test AUC: {:.2f}%'.format(100. * self.test_auc))
    logger.info('Test Time: {:.3f}s'.format(self.test_time))
    logger.info('Finished testing.')
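
# A minimal self-contained sketch of the scoring step in test(): the
# probability that q(y|x) assigns to the outlier class serves as the anomaly
# score, and ROC AUC is computed over all (label, score) pairs. Dummy tensors
# stand in for net.classify(u) and the test labels; nothing here touches the
# repo's actual dataset or network.

if __name__ == '__main__':
    dummy_logits = torch.softmax(torch.randn(16, 2), dim=1)  # stand-in for net.classify(u)
    dummy_labels = np.array([0] * 8 + [1] * 8)               # 0: normal, 1: outlier
    dummy_scores = dummy_logits[:, 1].numpy()                # outlier confidence as anomaly score
    print('AUC: {:.2f}%'.format(100. * roc_auc_score(dummy_labels, dummy_scores)))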