def evaluate_AE(self, dataset, batch_size=32, n_job_dataloader=0, device='cuda',
                print_batch_progress=False, set='test'):
    """ Evaluate the AE to get the embedding.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) The dataset on which to evaluate
        |           the ae_net. It must return the input, label, mask, semi-
        |           supervised label and the index.
        |---- batch_size (int) the batch_size to use.
        |---- n_job_dataloader (int) number of workers for the dataloader.
        |---- device (str) the device to work on ('cpu' or 'cuda').
        |---- print_batch_progress (bool) whether to display the batch
        |           progress bar.
        |---- set (str) The nature of the set evaluated. Must be either 'valid' or 'test'.
    OUTPUT
        |---- None
    RAISES
        |---- ValueError if `set` is neither 'valid' nor 'test'.
    """
    # NOTE: `set` shadows the builtin, but renaming it would break callers that
    # pass it as a keyword argument, so the name is kept.
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if set not in ('valid', 'test'):
        raise ValueError(f'Invalid set provided : {set} was given. Expected either valid or test.')

    # Build a bare trainer if the AE was loaded from disk rather than trained here.
    if self.ae_trainer is None:
        self.ae_trainer = AE_trainer(batch_size=batch_size, n_job_dataloader=n_job_dataloader,
                                     device=device, print_batch_progress=print_batch_progress)

    # Evaluate model
    self.ae_trainer.evaluate(self.ae_net, dataset, save_tSNE=True, return_auc=False,
                             print_to_logger=True)

    # Store the embedding produced during evaluation for the requested set.
    self.results['AE'][set]['embedding'] = self.ae_trainer.eval_repr
def train_AE(self, dataset, valid_dataset=None, n_epoch=100, batch_size=32, lr=1e-3,
             weight_decay=1e-6, lr_milestone=(), n_job_dataloader=0, device='cuda',
             print_batch_progress=False):
    """ Pretrain the encoder with AE learning.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) The dataset on which to train
        |           the ae_net. It must return the input image the mask and
        |           the semi-supervised label.
        |---- valid_dataset (torch.utils.data.Dataset) optional dataset used by
        |           the trainer to validate the model; no validation if None.
        |---- n_epoch (int) the number of epoch.
        |---- batch_size (int) the batch_size to use.
        |---- lr (float) the learning rate.
        |---- weight_decay (float) the weight_decay for the Adam optimizer.
        |---- lr_milestone (tuple) the lr update steps.
        |---- n_job_dataloader (int) number of workers for the dataloader.
        |---- device (str) the device to work on ('cpu' or 'cuda').
        |---- print_batch_progress (bool) whether to display the batch
        |           progress bar.
    OUTPUT
        |---- None
    """
    self.ae_trainer = AE_trainer(n_epoch=n_epoch, batch_size=batch_size, lr=lr,
                                 weight_decay=weight_decay, lr_milestone=lr_milestone,
                                 n_job_dataloader=n_job_dataloader, device=device,
                                 print_batch_progress=print_batch_progress)

    # train the autoencoder
    self.ae_net = self.ae_trainer.train(self.ae_net, dataset, valid_dataset=valid_dataset)

    # record training statistics
    self.results['AE']['train']['time'] = self.ae_trainer.train_time
    self.results['AE']['train']['loss'] = self.ae_trainer.train_loss
class AE_DSAD:
    """ Define a Deep SAD anomaly detector with an AE pretraining of the encoder.

    Workflow: train_AE -> transfer_encoder -> train_AD -> evaluate_AD.
    All metrics are accumulated in the nested `self.results` dictionary and
    can be exported with `save_results`.
    """
    def __init__(self, ae_net, AD_net, eta=1.0):
        """ Build a AE_DSAD.
        ----------
        INPUT
            |---- ae_net (nn.Module) the network to use for the AE learning.
            |---- AD_net (nn.Module) the DSAD network for anomaly detection.
            |---- eta (float) the semi-supervised importance in the DSAD loss.
        OUTPUT
            |---- None
        """
        self.ae_net = ae_net
        self.AD_net = AD_net
        self.c = None  # DSAD hypersphere center; set by train_AD or load_AD
        self.eta = eta

        # Trainers are created lazily by the train_*/evaluate_* methods.
        self.ae_trainer = None
        self.AD_trainer = None

        # Nested results structure filled in by the train/evaluate methods.
        self.results = {
            'AE': {
                'train': {'time': None, 'loss': None},
                'valid': {'embedding': None},
                'test': {'embedding': None}
            },
            'AD': {
                'train': {'time': None, 'loss': None},
                'valid': {'time': None, 'auc': None, 'scores': None},
                'test': {'time': None, 'auc': None, 'scores': None}
            }
        }

    def train_AE(self, dataset, valid_dataset=None, n_epoch=100, batch_size=32, lr=1e-3,
                 weight_decay=1e-6, lr_milestone=(), n_job_dataloader=0, device='cuda',
                 print_batch_progress=False):
        """ Pretrain the encoder with AE learning.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) The dataset on which to train
            |           the ae_net. It must return the input image the mask and
            |           the semi-supervised label.
            |---- valid_dataset (torch.utils.data.Dataset) optional dataset used
            |           by the trainer to validate the model; no validation if None.
            |---- n_epoch (int) the number of epoch.
            |---- batch_size (int) the batch_size to use.
            |---- lr (float) the learning rate.
            |---- weight_decay (float) the weight_decay for the Adam optimizer.
            |---- lr_milestone (tuple) the lr update steps.
            |---- n_job_dataloader (int) number of workers for the dataloader.
            |---- device (str) the device to work on ('cpu' or 'cuda').
            |---- print_batch_progress (bool) whether to display the batch
            |           progress bar.
        OUTPUT
            |---- None
        """
        self.ae_trainer = AE_trainer(n_epoch=n_epoch, batch_size=batch_size, lr=lr,
                                     weight_decay=weight_decay, lr_milestone=lr_milestone,
                                     n_job_dataloader=n_job_dataloader, device=device,
                                     print_batch_progress=print_batch_progress)

        # train the autoencoder
        self.ae_net = self.ae_trainer.train(self.ae_net, dataset, valid_dataset=valid_dataset)

        # record training statistics
        self.results['AE']['train']['time'] = self.ae_trainer.train_time
        self.results['AE']['train']['loss'] = self.ae_trainer.train_loss

    def evaluate_AE(self, dataset, batch_size=32, n_job_dataloader=0, device='cuda',
                    print_batch_progress=False, set='test'):
        """ Evaluate the AE to get the embedding.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) The dataset on which to evaluate
            |           the ae_net. It must return the input, label, mask, semi-
            |           supervised label and the index.
            |---- batch_size (int) the batch_size to use.
            |---- n_job_dataloader (int) number of workers for the dataloader.
            |---- device (str) the device to work on ('cpu' or 'cuda').
            |---- print_batch_progress (bool) whether to display the batch
            |           progress bar.
            |---- set (str) The nature of the set evaluated. Must be either 'valid' or 'test'.
        OUTPUT
            |---- None
        RAISES
            |---- ValueError if `set` is neither 'valid' nor 'test'.
        """
        # NOTE: `set` shadows the builtin, but renaming it would break callers
        # that pass it as a keyword argument, so the name is kept.
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if set not in ('valid', 'test'):
            raise ValueError(f'Invalid set provided : {set} was given. Expected either valid or test.')

        # Build a bare trainer if the AE was loaded from disk rather than trained here.
        if self.ae_trainer is None:
            self.ae_trainer = AE_trainer(batch_size=batch_size, n_job_dataloader=n_job_dataloader,
                                         device=device, print_batch_progress=print_batch_progress)

        # Evaluate model
        self.ae_trainer.evaluate(self.ae_net, dataset, save_tSNE=True, return_auc=False,
                                 print_to_logger=True)

        # Store the embedding produced during evaluation for the requested set.
        self.results['AE'][set]['embedding'] = self.ae_trainer.eval_repr

    def save_ae_net(self, export_fn):
        """ Save the AE model.
        ----------
        INPUT
            |---- export_fn (str) the export path.
        OUTPUT
            |---- None
        """
        torch.save({'ae_net_dict': self.ae_net.state_dict()}, export_fn)

    def load_ae_net(self, import_fn, map_location='cpu'):
        """ Load the AE model.
        ----------
        INPUT
            |---- import_fn (str) path where to get the model.
            |---- map_location (str) device on which to load the model.
        OUTPUT
            |---- None
        """
        model = torch.load(import_fn, map_location=map_location)
        self.ae_net.load_state_dict(model['ae_net_dict'])

    def transfer_encoder(self):
        """ Transfer the weight of the encoder learnt by AE learning to the
        encoder of DSAD.
        ----------
        INPUT
            |---- None
        OUTPUT
            |---- None
        """
        # get both encoder state dicts
        AD_encoder_dict = self.AD_net.encoder.state_dict()
        AE_encoder_dict = self.ae_net.encoder.state_dict()
        # keep only the keys shared by both encoders, so architectures that
        # differ slightly (e.g. extra heads) still transfer what they can
        new_encoder_dict = {k: v for k, v in AE_encoder_dict.items() if k in AD_encoder_dict}
        # update the DSAD encoder weights with the ae_net encoder ones
        AD_encoder_dict.update(new_encoder_dict)
        self.AD_net.encoder.load_state_dict(AD_encoder_dict)

    def train_AD(self, dataset, valid_dataset=None, n_epoch=100, batch_size=32, lr=1e-3,
                 weight_decay=1e-6, lr_milestone=(), n_job_dataloader=0, device='cuda',
                 print_batch_progress=False):
        """ Train the encoder on the DSAD objective.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) The dataset on which to train
            |           the AD_net. It must return the input, label, mask, semi-
            |           supervised label and the index.
            |---- valid_dataset (torch.utils.data.Dataset) The dataset on which
            |           to validate the model at each epoch. No validation is
            |           performed if not provided.
            |---- n_epoch (int) the number of epoch.
            |---- batch_size (int) the batch_size to use.
            |---- lr (float) the learning rate.
            |---- weight_decay (float) the weight_decay for the Adam optimizer.
            |---- lr_milestone (tuple) the lr update steps.
            |---- n_job_dataloader (int) number of workers for the dataloader.
            |---- device (str) the device to work on ('cpu' or 'cuda').
            |---- print_batch_progress (bool) whether to display the batch
            |           progress bar.
        OUTPUT
            |---- None
        """
        self.AD_trainer = DSAD_trainer(self.c, self.eta, n_epoch=n_epoch, batch_size=batch_size,
                                       lr=lr, weight_decay=weight_decay, lr_milestone=lr_milestone,
                                       n_job_dataloader=n_job_dataloader, device=device,
                                       print_batch_progress=print_batch_progress)

        # Train the DSAD network
        self.AD_net = self.AD_trainer.train(dataset, self.AD_net, valid_dataset=valid_dataset)

        # record training statistics and the learned hypersphere center
        self.results['AD']['train']['time'] = self.AD_trainer.train_time
        self.results['AD']['train']['loss'] = self.AD_trainer.train_loss
        # .detach() replaces the deprecated Tensor.data access; the center is
        # kept as a plain list so it is JSON-serializable by save_results/save_AD
        self.c = self.AD_trainer.c.detach().cpu().numpy().tolist()

    def evaluate_AD(self, dataset, batch_size=32, n_job_dataloader=0, device='cuda',
                    print_batch_progress=False, set='test'):
        """ Evaluate the encoder on the DSAD objective.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) The dataset on which to evaluate
            |           the AD_net. It must return the input, label, mask, semi-
            |           supervised label and the index.
            |---- batch_size (int) the batch_size to use.
            |---- n_job_dataloader (int) number of workers for the dataloader.
            |---- device (str) the device to work on ('cpu' or 'cuda').
            |---- print_batch_progress (bool) whether to display the batch
            |           progress bar.
            |---- set (str) The nature of the set evaluated. Must be either 'valid' or 'test'.
        OUTPUT
            |---- None
        RAISES
            |---- ValueError if `set` is neither 'valid' nor 'test'.
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if set not in ('valid', 'test'):
            raise ValueError(f'Invalid set provided : {set} was given. Expected either valid or test.')

        # Build a bare trainer if the model was loaded from disk rather than trained here.
        if self.AD_trainer is None:
            self.AD_trainer = DSAD_trainer(self.c, self.eta, batch_size=batch_size,
                                           n_job_dataloader=n_job_dataloader, device=device,
                                           print_batch_progress=print_batch_progress)

        # Evaluate model
        # NOTE(review): return_auc=False yet eval_auc is read below — the trainer
        # presumably stores the AUC as an attribute regardless; confirm in DSAD_trainer.
        self.AD_trainer.evaluate(self.AD_net, dataset, return_auc=False, print_to_logger=True,
                                 save_tSNE=True)

        # Store the evaluation metrics for the requested set.
        self.results['AD'][set]['time'] = self.AD_trainer.eval_time
        self.results['AD'][set]['auc'] = self.AD_trainer.eval_auc
        self.results['AD'][set]['scores'] = self.AD_trainer.eval_scores

    def save_AD(self, export_fn):
        """ Save the DSAD model (center and state dictionary).
        ----------
        INPUT
            |---- export_fn (str) the export path.
        OUTPUT
            |---- None
        """
        torch.save({'c': self.c, 'AD_net_dict': self.AD_net.state_dict()}, export_fn)

    def load_AD(self, import_fn, map_location='cpu'):
        """ Load the DSAD model from location.
        ----------
        INPUT
            |---- import_fn (str) path where to get the model.
            |---- map_location (str) device on which to load the model.
        OUTPUT
            |---- None
        """
        model = torch.load(import_fn, map_location=map_location)
        self.c = model['c']
        self.AD_net.load_state_dict(model['AD_net_dict'])

    def save_results(self, export_fn):
        """ Save the model results in JSON.
        ----------
        INPUT
            |---- export_fn (str) path where to write the results.
        OUTPUT
            |---- None
        """
        with open(export_fn, 'w') as f:
            json.dump(self.results, f)