Example #1
def test_AE(ae_net,
            dataset,
            batch_size=16,
            n_jobs_dataloader=4,
            device='cuda'):
    # make test dataloader using image and mask
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                    shuffle=True, num_workers=n_jobs_dataloader)

    # MSE loss without reduction --> one MSE value per output pixel
    criterion = MaskedMSELoss(reduction='none')

    # set to device
    ae_net = ae_net.to(device)
    criterion = criterion.to(device)

    # Testing
    epoch_loss = 0.0
    n_batch = 0
    start_time = time.time()
    idx_label_score = []
    # put network in evaluation mode
    ae_net.eval()

    with torch.no_grad():
        for b, data in enumerate(loader):
            input, label, mask, _, idx = data
            # put inputs to device
            input, label = input.to(device).float(), label.to(device)
            mask, idx = mask.to(device), idx.to(device)

            rec = ae_net(input)
            rec_loss = criterion(rec, input, mask)
            score = torch.mean(rec_loss, dim=tuple(range(
                1, rec.dim())))  # mean over all dimensions per sample

            # append scores and label
            idx_label_score += list(
                zip(idx.cpu().data.numpy().tolist(),
                    label.cpu().data.numpy().tolist(),
                    score.cpu().data.numpy().tolist()))
            # overall batch loss
            loss = torch.sum(rec_loss) / torch.sum(mask)
            epoch_loss += loss.item()
            n_batch += 1

            print_progessbar(b, len(loader), Name='\t\tBatch', Size=20)

    test_time = time.time() - start_time
    scores = idx_label_score
    return test_time, scores
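
Note: the snippets above and below rely on a MaskedMSELoss criterion that is not included on this page. The following is only a minimal sketch of what such a masked MSE could look like, based on the (rec, target, mask) call signature and the reduction argument seen above; the actual implementation may differ.

import torch
import torch.nn as nn

class MaskedMSELoss(nn.Module):
    """Sketch: MSE computed only where mask == 1, assumed to mirror the criterion used above."""
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, rec, target, mask):
        # squared error restricted to the masked region
        err = (rec - target) ** 2 * mask
        if self.reduction == 'none':
            return err  # per-pixel loss, as used for the anomaly scores above
        return torch.sum(err) / torch.clamp(torch.sum(mask), min=1.0)  # mean over masked pixels only

With reduction='none' it returns the per-pixel error map that the scoring code averages per sample; with the default reduction it averages over the masked pixels only.
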
    def test(self, dataset, net):
        """
        Test the joint DeepSVDD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is tested. It must return an image and
            |           semi-supervised labels.
            |---- net (nn.Module) The DeepSVDD to test. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None
        """
        logger = logging.getLogger()

        # make test dataloader using image and mask
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                        shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss(reduction='none')
        criterion_ad = self.SVDDLoss(self.space_repr,
                                     self.nu,
                                     eps=self.eps,
                                     soft_boundary=self.soft_boundary)

        # Testing
        logger.info('>>> Start Testing the joint DeepSVDD and AutoEncoder.')
        epoch_loss = 0.0
        n_batch = 0
        n_batch_tot = len(test_loader)
        start_time = time.time()
        idx_label_score_rec = []
        idx_label_score_ad = []
        # put network in evaluation mode
        net.eval()
        with torch.no_grad():
            for b, data in enumerate(test_loader):
                input, label, mask, semi_label, idx = data
                # put data to device
                input, label = input.to(self.device).float(), label.to(
                    self.device)
                mask, semi_label = mask.to(self.device), semi_label.to(
                    self.device)
                idx = idx.to(self.device)

                # mask the input
                input = input * mask

                # compute loss
                rec, embed = net(input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_ad = criterion_ad(embed, self.R)
                # compute anomaly scores
                rec_score = torch.mean(
                    loss_rec, dim=tuple(range(
                        1, rec.dim())))  # mean over all dimensions per sample
                if self.use_subspace:
                    dist = torch.sum(
                        (embed -
                         torch.matmul(self.space_repr, embed.transpose(
                             0, 1)).transpose(0, 1))**2,
                        dim=1
                    )  # score is the distance (large distances highlight anomalies)
                else:
                    dist = torch.sum(
                        (embed - self.space_repr)**2, dim=1
                    )  # score is the distance (large distances highlight anomalies)

                if self.soft_boundary:
                    ad_score = dist - self.R**2
                else:
                    ad_score = dist

                # get overall loss
                mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
                loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
                loss += self.scale_em * self.criterion_weight[1] * loss_ad

                # append scores and label
                idx_label_score_rec += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        ad_score.cpu().data.numpy().tolist()))

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        self.test_time = time.time() - start_time
        self.test_scores_rec = idx_label_score_rec
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        self.test_auc_rec = roc_auc_score(label, rec_score)
        self.test_f1_rec = f1_score(
            label, np.where(rec_score > self.scores_threhold_rec, 1, 0))

        self.test_scores_ad = idx_label_score_ad
        _, label, ad_score = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        self.test_auc_ad = roc_auc_score(label, ad_score)
        self.test_f1_ad = f1_score(
            label, np.where(ad_score > self.scores_threhold_ad, 1, 0))

        # add info to logger
        logger.info(f'>>> Test Time: {self.test_time:.3f} [s]')
        logger.info(f'>>> Test Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'>>> Test reconstruction AUC: {self.test_auc_rec:.3%}')
        logger.info(
            f'>>> Test F1-score on reconstruction score: {self.test_f1_rec:.3%}'
        )
        logger.info(f'>>> Test AD AUC: {self.test_auc_ad:.3%}')
        logger.info(
            f'>>> Test F1-score on DeepSVDD score: {self.test_f1_ad:.3%}')
        logger.info(
            '>>> Finished Testing the Joint DeepSVDD and AutoEncoder.\n')
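
Note: the anomaly score in this test method is the squared distance of the embedding to the hypersphere representation (a center, or a projection matrix when use_subspace is set), shifted by R**2 in the soft-boundary case so that negative scores fall inside the sphere. A small, self-contained illustration on random embeddings; all names below are local to the sketch.

import torch

torch.manual_seed(0)
embed = torch.randn(8, 32)          # batch of 8 embeddings of dimension 32
c = torch.zeros(32)                 # hypersphere center
R = torch.tensor(1.5)               # hypersphere radius (soft-boundary variant)

dist = torch.sum((embed - c) ** 2, dim=1)   # squared distance per sample
ad_score_soft = dist - R ** 2               # negative = inside the sphere (normal), positive = outside (anomalous)
ad_score_hard = dist                        # one-class variant: the raw distance is the score
print(ad_score_soft.shape, ad_score_hard.shape)  # torch.Size([8]) torch.Size([8])
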
    def train(self, dataset, net, valid_dataset=None):
        """
        Train the joint DeepSVDD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels.
            |---- net (nn.Module) The DeepSVDD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The trained joint DeepSVDD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True  # enable the network to provide the SVDD embedding

        # initialize hypersphere center or subspace projection matrix
        if self.space_repr is None:
            if self.use_subspace:
                logger.info('>>> Initializing the subspace projection matrix.')
                self.space_repr = self.initialize_projection_matrix(
                    train_loader, net)
                logger.info('>>> Projection matrix successfully initialized.')
            else:
                logger.info('>>> Initializing the hypersphere center.')
                self.space_repr = self.initialize_hypersphere_center(
                    train_loader, net)
                logger.info('>>> Center successfully initialized.')

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss()
        criterion_ad = self.SVDDLoss(self.space_repr,
                                     self.nu,
                                     eps=self.eps,
                                     soft_boundary=self.soft_boundary)

        # compute the scale weights so that the reconstruction and SVDD losses are scaled to be comparable
        self.initialize_loss_scale_weight(train_loader, net, criterion_rec,
                                          criterion_ad)

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the Joint DeepSVDD and Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()
            dist = []

            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                # put inputs to device
                input, mask, semi_label = input.to(
                    self.device).float(), mask.to(self.device), semi_label.to(
                        self.device)
                input.requires_grad = True

                # mask the input (keep only the object)
                input = input * mask

                # zeros the gradient
                optimizer.zero_grad()

                # Update network parameters by backpropagation on the two criterion
                rec, embed = net(input)
                # reconstruction loss
                # ignore reconstruction for known abnormal samples (no gradient update because loss = 0)
                rec = torch.where(
                    semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1,
                    rec, input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
                # SVDD embedding loss
                loss_ad = criterion_ad(embed, self.R)
                loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad
                loss = loss_rec + loss_ad

                loss.backward()
                optimizer.step()

                # compute dist to update radius R
                if self.soft_boundary and (epoch + 1 > self.n_epoch_warm_up):
                    if self.use_subspace:
                        dist.append(
                            torch.sum((embed - torch.matmul(
                                self.space_repr, embed.transpose(
                                    0, 1)).transpose(0, 1))**2,
                                      dim=1).detach())
                    else:
                        dist.append(
                            torch.sum((self.space_repr - embed)**2,
                                      dim=1).detach())

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

            # update radius
            if self.soft_boundary and (epoch + 1 > self.n_epoch_warm_up):
                self.R.data = torch.tensor(self.get_radius(
                    torch.cat(dist, dim=0)),
                                           device=self.device)

            valid_auc = ''
            if valid_dataset:
                auc_rec, auc_ad = self.validate(valid_dataset,
                                                net,
                                                final=False)
                net.train()
                valid_auc = f' Rec AUC: {auc_rec:.3%} | AD AUC: {auc_ad:.3%} | R {self.R:.3f} |'

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | Train Time: {epoch_train_time:.3f} [s] '
                f'| Train Loss: {epoch_loss / n_batch:.6f} |' + valid_auc)

            # append the epoch loss to results list
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(
                    f'>>> LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}'
                )

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'>>> Training of Joint DeepSVDD and AutoEncoder Time: {self.train_time:.3f} [s]'
        )
        logger.info('>>> Finished Joint DeepSVDD and AutoEncoder Training.\n')

        return net
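
Note: in the soft-boundary variant the radius R is re-estimated after each epoch past the warm-up from the distances collected during training. get_radius is not shown on this page; the usual Deep SVDD choice, assumed here, is the (1 - nu) quantile of the distances, written below as a free function taking nu explicitly (the real method presumably reads nu from self).

import numpy as np
import torch

def get_radius(dist: torch.Tensor, nu: float) -> float:
    """Sketch of a radius update: (1 - nu) quantile of the distances.
    dist holds the squared distances accumulated in the training loop above (assumption)."""
    return float(np.quantile(np.sqrt(dist.detach().cpu().numpy()), 1 - nu))
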
    def pretrain(self, dataset, net):
        """
        Pretrain the AE for the joint DeepSVDD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels.
            |---- net (nn.Module) The DeepSVDD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The pretrained joint DeepSVDD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)
        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = False

        # define the two criterion for reconstruction
        criterion_rec = MaskedMSELoss()

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Start training
        logger.info('>>> Start Pretraining the Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)

        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch_pretrain):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                # put inputs to device
                input, mask, semi_label = input.to(
                    self.device).float(), mask.to(self.device), semi_label.to(
                        self.device)
                input.requires_grad = True

                # mask the input (keep only the object)
                input = input * mask

                # zeros the gradient
                optimizer.zero_grad()

                # Update network parameters by backpropagation
                rec, _ = net(input)
                # ignore reconstruction for known abnormal samples (no gradient update because loss = 0)
                rec = torch.where(
                    semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1,
                    rec, input)
                loss = criterion_rec(rec, input, mask)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epoch_pretrain:03} | Pretrain Time: {epoch_train_time:.3f} [s] '
                f'| Pretrain Loss: {epoch_loss / n_batch:.6f} |')

            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # End training
        self.pretrain_loss = epoch_loss_list
        self.pretrain_time = time.time() - start_time
        logger.info(
            f'>>> Pretraining of AutoEncoder Time: {self.pretrain_time:.3f} [s]'
        )
        logger.info('>>> Finished AutoEncoder Pretraining.\n')

        return net
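
Note: both training loops above cancel the reconstruction gradient of known abnormal samples (semi_label == -1) by swapping their reconstruction for the input before computing the loss, so the masked MSE on those samples is exactly zero. A tiny demonstration of that torch.where trick on dummy tensors:

import torch

rec = torch.rand(4, 1, 8, 8)                 # fake reconstructions
input = torch.rand(4, 1, 8, 8)               # fake inputs
semi_label = torch.tensor([0, -1, 0, -1])    # -1 marks known abnormal samples

# broadcast the per-sample label to the image shape and swap in the input where label == -1
rec = torch.where(semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1, rec, input)
per_sample_err = ((rec - input) ** 2).flatten(1).sum(dim=1)
print(per_sample_err)   # zero error, hence zero gradient, for the two abnormal samples
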
Example #5
    def validate(self, dataset, net):
        """
        Validate the joint DMSVDD network on the provided dataset and find the
        best threshold on the score to maximize the f1-score.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is validated. It must return an image and
            |           semi-supervised labels.
            |---- net (nn.Module) The DMSVDD to validate. The network should be
            |           an autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None
        """
        logger = logging.getLogger()

        # make test dataloader using image and mask
        valid_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                        shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss(reduction='none')
        criterion_ad = DMSVDDLoss(self.nu, eps=self.eps, soft_boundary=self.soft_boundary)

        # Testing
        logger.info('>>> Start Validating the joint DMSVDD and AutoEncoder.')
        epoch_loss = 0.0
        n_batch = 0
        n_batch_tot = len(valid_loader)
        start_time = time.time()
        idx_label_score_rec = []
        idx_label_score_ad = []
        # put network in evaluation mode
        net.eval()
        with torch.no_grad():
            for b, data in enumerate(valid_loader):
                input, label, mask, semi_label, idx = data
                # put data to device
                input, label = input.to(self.device).float(), label.to(self.device)
                mask, semi_label = mask.to(self.device), semi_label.to(self.device)
                idx = idx.to(self.device)

                # mask the input
                input = input * mask

                # compute loss
                rec, embed = net(input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_ad = criterion_ad(embed, self.c, self.R)
                # compute anomaly scores
                rec_score = torch.mean(loss_rec, dim=tuple(range(1, rec.dim()))) # mean over all dimensions per sample

                dist, sphere_idx = torch.min(torch.sum((self.c.unsqueeze(0) - embed.unsqueeze(1))**2, dim=2), dim=1) # distance to and index of the closest sphere per sample (do not overwrite the sample indices in idx)
                if self.soft_boundary:
                    ad_score = dist - torch.stack([self.R[i] ** 2 for i in sphere_idx], dim=0) # dist - R**2 --> negative = normal ; positive = abnormal
                else:
                    ad_score = dist

                # compute overall loss
                mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
                loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
                loss += self.scale_em * self.criterion_weight[1] * loss_ad

                # append scores and label
                idx_label_score_rec += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            ad_score.cpu().data.numpy().tolist()))

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        self.valid_time = time.time() - start_time
        self.valid_scores_rec = idx_label_score_rec
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        self.valid_auc_rec = roc_auc_score(label, rec_score)
        self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(rec_score, label, metric=f1_score)

        self.valid_scores_ad = idx_label_score_ad
        _, label, ad_score = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        self.valid_auc_ad = roc_auc_score(label, ad_score)
        self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(ad_score, label, metric=f1_score)

        # add info to logger
        logger.info(f'>>> Validation Time: {self.valid_time:.3f} [s]')
        logger.info(f'>>> Validation Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'>>> Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
        logger.info(f'>>> Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}')
        logger.info(f'>>> Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}')
        logger.info(f'>>> Validation DMSVDD AUC: {self.valid_auc_ad:.3%}')
        logger.info(f'>>> Best Threshold for the DMSVDD score maximizing F1-score: {self.scores_threhold_ad:.3f}')
        logger.info(f'>>> Best F1-score on DMSVDD score: {self.valid_f1_ad:.3%}')
        logger.info('>>> Finished validating the Joint DMSVDD and AutoEncoder.\n')
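
Note: get_best_threshold is used above to pick the score threshold that maximizes the F1-score on the validation set, but it is not part of the snippet. A plausible, self-contained sketch follows; the candidate grid and tie-breaking are assumptions.

import numpy as np
from sklearn.metrics import f1_score

def get_best_threshold(scores, labels, metric=f1_score, n_thresholds=200):
    """Sketch: scan candidate thresholds over the score range and keep the one maximizing metric.
    scores and labels are 1D numpy arrays, labels in {0, 1}."""
    candidates = np.linspace(scores.min(), scores.max(), n_thresholds)
    best_t, best_m = candidates[0], -np.inf
    for t in candidates:
        m = metric(labels, np.where(scores > t, 1, 0))
        if m > best_m:
            best_t, best_m = t, m
    return best_t, best_m
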
    def evaluate(self, net, dataset, mode='test', final=False):
        """
        Evaluate the model with the given dataset.
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD to validate. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is validated. It must return an image, a mask and
            |           semi-supervised labels.
            |---- mode (str) either 'valid' or 'test'. Define the evaluation mode.
            |           In 'valid' the evaluation can return the reconstruction
            |           and MSAD AUCs and compute the best threshold to maximize
            |           the F1-scores. In test mode the validation threshold is
            |           used to compute the F1-score.
            |---- final (bool) whether the call represents the final validation,
            |           in which case the validation results are saved. Only
            |           relevant if mode is 'valid'.
        OUTPUT
            |---- auc (tuple (reconstruction auc, ad auc)) the validation AUCs for
            |           both scores are returned only if final is False. Otherwise
            |           None is returned.
        """
        assert mode in [
            'valid', 'test'
        ], f'Mode {mode} is not supported. Should be either "valid" or "test".'
        logger = logging.getLogger()

        # make the dataloader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        # put net on device
        net = net.to(self.device)
        # set the network to provide both the reconstruction and the embedding
        net.return_svdd_embed = True

        # define the two loss function
        loss_fn_rec = MaskedMSELoss(
            reduction='none'
        )  # no reduction to compute AD score for each sample
        loss_fn_ad = DMSADLoss(self.eta, self.eps)

        # Validate
        if final or mode == 'test':
            logger.info(
                f' Start Evaluating the jointly trained DMSAD and AutoEncoder in {mode} mode.'
            )
        epoch_loss = 0.0
        n_batch = len(loader)
        start_time = time.time()
        idx_label_score_rec, idx_label_score_ad = [], []  # placeholders for the scores

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data on device
                input, label, mask, semi_label, idx = data
                input = input.to(self.device).float()
                label = label.to(self.device)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                idx = idx.to(self.device)

                # mask input
                input = input * mask

                # compute the loss
                rec, embed = net(input)
                loss_rec = loss_fn_rec(rec, input, mask)
                loss_ad = loss_fn_ad(embed, self.c, semi_label)
                # get reconstruction anomaly scores : mean loss by sample
                rec_score = torch.mean(loss_rec,
                                       dim=tuple(range(1, rec.dim())))

                # find closest sphere
                dist, sphere_idx = torch.min(torch.norm(self.c.unsqueeze(0) -
                                                        embed.unsqueeze(1),
                                                        p=2,
                                                        dim=2),
                                             dim=1)

                if self.R is not None:
                    # anomaly scores positive if dist > R and negative if dist < R
                    ad_score = dist - torch.stack(
                        [self.R[j] for j in sphere_idx], dim=0)
                else:
                    # else scores is just the minimal distance to a center
                    ad_score = dist

                # append scores to the placeholder lists
                idx_label_score_rec += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        ad_score.cpu().data.numpy().tolist(),
                        sphere_idx.cpu().data.numpy().tolist(),
                        embed.cpu().data.numpy().tolist()))

                # compute the overall loss
                loss = self.scale_rec * self.criterion_weight[0] * (
                    torch.sum(loss_rec) / torch.sum(mask))
                loss += self.scale_ad * self.criterion_weight[1] * loss_ad
                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\t Evaluation Batch',
                                     Size=40,
                                     erase=True)

        # compute AUCs
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        auc_rec = roc_auc_score(label, rec_score)

        _, label, ad_score, _, _ = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        auc_ad = roc_auc_score(label, ad_score)

        if mode == 'valid':
            if final:
                # save results
                self.valid_time = time.time() - start_time
                self.valid_scores_rec = idx_label_score_rec
                self.valid_auc_rec = auc_rec
                self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(
                    rec_score, label, metric=f1_score)
                self.valid_scores_ad = idx_label_score_ad
                self.valid_auc_ad = auc_ad
                self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(
                    ad_score, label, metric=f1_score)
                # print infos
                logger.info(f'---- Validation Time: {self.valid_time:.3f} [s]')
                logger.info(
                    f'---- Validation Loss: {epoch_loss / n_batch:.6f}')
                logger.info(
                    f'---- Validation reconstruction AUC: {self.valid_auc_rec:.3%}'
                )
                logger.info(
                    f'---- Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}'
                )
                logger.info(
                    f'---- Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}'
                )
                logger.info(
                    f'---- Validation MSAD AUC: {self.valid_auc_ad:.3%}')
                logger.info(
                    f'---- Best Threshold for the MSAD score maximizing F1-score: {self.scores_threhold_ad:.3f}'
                )
                logger.info(
                    f'---- Best F1-score on MSAD score: {self.valid_f1_ad:.3%}'
                )
                logger.info(
                    '---- Finished validating the Joint DMSAD and AutoEncoder.\n'
                )
            else:
                return auc_rec, auc_ad

        elif mode == 'test':
            # save results
            self.test_time = time.time() - start_time
            self.test_scores_rec = idx_label_score_rec
            self.test_auc_rec = auc_rec
            self.test_scores_ad = idx_label_score_ad
            self.test_auc_ad = auc_ad

            # print infos
            logger.info(f'---- Test Time: {self.test_time:.3f} [s]')
            logger.info(f'---- Test Loss: {epoch_loss / n_batch:.6f}')
            logger.info(
                f'---- Test reconstruction AUC: {self.test_auc_rec:.3%}')
            if self.scores_threhold_rec is not None:
                self.test_f1_rec = f1_score(
                    label, np.where(rec_score > self.scores_threhold_rec, 1,
                                    0))
                logger.info(
                    f'---- Best F1-score on reconstruction score: {self.test_f1_rec:.3%}'
                )
            logger.info(f'---- Test MSAD AUC: {self.test_auc_ad:.3%}')
            if self.scores_threhold_ad is not None:
                self.test_f1_ad = f1_score(
                    label, np.where(ad_score > self.scores_threhold_ad, 1, 0))
                logger.info(
                    f'---- Best F1-score on MSAD score: {self.test_f1_ad:.3%}')
            logger.info(
                '---- Finished testing the Joint DMSAD and AutoEncoder.\n')
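
Note: the multi-sphere score above comes from a broadcasted pairwise distance: self.c.unsqueeze(0) - embed.unsqueeze(1) has shape (batch, n_centers, embed_dim), its L2 norm over the last dimension gives a (batch, n_centers) distance matrix, and the minimum over dim=1 yields the closest sphere and the distance to it. Illustrated on dummy tensors:

import torch

embed = torch.randn(8, 32)      # batch of embeddings
c = torch.randn(5, 32)          # 5 hypersphere centers

# pairwise L2 distances: shape (batch, n_centers)
pairwise = torch.norm(c.unsqueeze(0) - embed.unsqueeze(1), p=2, dim=2)
dist, sphere_idx = torch.min(pairwise, dim=1)   # distance to, and index of, the closest center
print(dist.shape, sphere_idx.shape)             # torch.Size([8]) torch.Size([8])
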
Example #7
    def train(self, dataset, net):
        """
        Train the DMSVDD on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels.
            |---- net (nn.Module) The DMSVDD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The pretrained joint DMSVDD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True # enable the network to provide the SVDD embedding

        # initialize hypersphere center or subspace projection matrix
        if self.c is None:
            logger.info('>>> Initializing the hypersphere centers.')
            self.initialize_centers(train_loader, net)
            logger.info(f'>>> {self.n_sphere_init} centers successfully initialized.')

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss()
        criterion_ad = DMSVDDLoss(self.nu, eps=self.eps, soft_boundary=self.soft_boundary)

        # compute the scale weights so that the reconstruction and SVDD losses are scaled to be comparable
        self.initialize_loss_scale_weight(train_loader, net, criterion_rec, criterion_ad)

        # define optimizer
        optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # define scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the Joint DMSVDD and Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            # update network
            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                # put inputs to device
                input, mask, semi_label = input.to(self.device).float(), mask.to(self.device), semi_label.to(self.device)
                input.requires_grad = True

                # mask the input (keep only the object)
                input = input * mask

                # zeros the gradient
                optimizer.zero_grad()

                # Update network parameters by backpropagation on the two criterion
                rec, embed = net(input)
                # reconstruction loss
                # ignore reconstruction for known abnormal samples (no gradient update because loss = 0)
                rec = torch.where(semi_label.view(-1,1,1,1).expand(*input.shape) != -1, rec, input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
                # SVDD embedding loss
                loss_ad = criterion_ad(embed, self.c, self.R)
                loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad
                loss = loss_rec + loss_ad

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tWeight-update Batch', Size=20)

            with torch.no_grad():
                # update radius R
                if epoch >= self.n_epoch_warm_up:
                    n_k = torch.zeros(self.c.shape[0], device=self.device)
                    dist = [[] for _ in range(self.c.shape[0])] # list of list for each center : N_center x N_k

                    for b, data in enumerate(train_loader):
                        # compute distance and belonging of sample
                        input, _, mask, semi_label, _ = data
                        input, mask, semi_label = input.to(self.device).float(), mask.to(self.device), semi_label.to(self.device)

                        # mask the input (keep only the object)
                        input = (input * mask)[semi_label != -1]
                        _, embed = net(input)
                        # get closest centers
                        min_dist, idx = torch.min(torch.norm(self.c.unsqueeze(0) - embed.unsqueeze(1), p=2, dim=2), dim=1)
                        for i, d in zip(idx, min_dist):
                            n_k[i] += 1
                            dist[i].append(d)

                        if self.print_batch_progress:
                            print_progessbar(b, n_batch_tot, Name='\t\tRadius-update Batch', Size=20)

                    if self.soft_boundary:
                        # update R with (1-nu)th quantile
                        self.R = torch.where(n_k < self.nu * torch.max(n_k),
                                             torch.Tensor([0.0]).to(self.device),
                                             torch.Tensor([np.quantile(torch.stack(d, dim=0).clone().cpu().numpy(), 1 - self.nu) if len(d) > 0 else 0.0 for d in dist]).to(self.device))

                        # keep only centers and radius where R > 0
                        self.c = self.c[self.R > 0.0]
                        self.R = self.R[self.R > 0.0]
                    else:
                        # keep only centers that are not represented
                        self.c = self.c[n_k == 0] #self.c = self.c[n_k < self.nu * torch.max(n_k)]

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train Time: {epoch_train_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} '
                        f'| N spheres: {self.c.shape[0]:03} |')

            # append the epoch loss to results list
            epoch_loss_list.append([epoch+1, epoch_loss/n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(f'>>> LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}')

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(f'>>> Training of Joint DMSVDD and AutoEncoder Time: {self.train_time:.3f} [s]')
        logger.info('>>> Finished Joint DMSVDD and AutoEncoder Training.\n')

        return net
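
Note: initialize_centers is called above but not shown. In multi-sphere models the centers are typically obtained by running k-means on the embeddings of the training samples; the sketch below applies that idea to a pre-computed embedding matrix rather than a dataloader, so the exact procedure used here is an assumption.

import torch
from sklearn.cluster import KMeans

def initialize_centers_from_embeddings(embeddings: torch.Tensor, n_spheres: int = 10) -> torch.Tensor:
    """Sketch: k-means over an (n_samples, embed_dim) embedding matrix -> (n_spheres, embed_dim) centers.
    The real initialize_centers presumably embeds the whole train_loader with net first (assumption)."""
    km = KMeans(n_clusters=n_spheres, n_init=10).fit(embeddings.detach().cpu().numpy())
    return torch.tensor(km.cluster_centers_, dtype=torch.float32)
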
    def pretrain(self, net, dataset):
        """
        Pretrain the AE for the joint DMSAD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is pretrained. It must return a tuple (image,
            |           label, mask, semi-supervised labels, idx).
            |---- net (nn.Module) The DMSAD to pretrain. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The pretrained joint DMSAD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        # put net on device
        net = net.to(self.device)
        # set the network to provide only the reconstruction
        net.return_svdd_embed = False

        # define the reconstruction loss function
        loss_fn_rec = MaskedMSELoss()

        # define the optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Start training
        logger.info(' Start Pretraining the Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch_pretrain):
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get batch data
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float().requires_grad_(True)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # mask the input and keep only normal samples
                input = (input * mask)[semi_label != -1]
                mask = mask[semi_label != -1]

                # Update network parameters via backpropagation : forward + backward + optim
                optimizer.zero_grad()
                rec, _ = net(input)
                loss = loss_fn_rec(rec, input, mask)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tBatch',
                                     Size=40,
                                     erase=True)

            # print epoch statistics
            logger.info(
                f'----| Epoch {epoch + 1:03}/{self.n_epoch_pretrain:03} '
                f'| Pretrain Time {time.time() - epoch_start_time:.3f} [s] '
                f'| Pretrain Loss {epoch_loss / n_batch:.6f} |')
            # store loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

        # End training
        self.pretrain_loss = epoch_loss_list
        self.pretrain_time = time.time() - start_time
        logger.info(
            f'---- Finished Pretraining the AutoEncoder in {self.pretrain_time:.3f} [s].'
        )

        return net
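
Note: this pretraining loop drops known abnormal samples from each batch by boolean indexing along the batch dimension before the forward pass. A tiny illustration of that filtering:

import torch

input = torch.rand(4, 1, 8, 8)
mask = torch.ones(4, 1, 8, 8)
semi_label = torch.tensor([0, -1, 1, 0])     # -1 = known abnormal

kept = (input * mask)[semi_label != -1]      # drops the abnormal sample from the batch
print(kept.shape)                            # torch.Size([3, 1, 8, 8])
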
    def train(self, net, dataset, valid_dataset=None):
        """
        Train the DMSAD on the provided dataset.
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels.
            |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
            |           to validate the model at each epoch. No validation is
            |           performed if not provided.
        OUTPUT
            |---- net (nn.Module) The pretrained joint DMSAD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        # put net on device
        net = net.to(self.device)
        # set the network to provide both the reconstruction and the embedding
        net.return_svdd_embed = True

        # Initialize the hyper-sphere centers by Kmeans
        if self.c is None:
            logger.info(' Initializing the hypersphere centers.')
            self.initialize_centers(train_loader, net)
            logger.info(
                f' {self.c.shape[0]} centers successfully initialized.')

        # define the reconstruction loss function
        loss_fn_rec = MaskedMSELoss()
        loss_fn_ad = DMSADLoss(self.eta, eps=self.eps)

        # Compute the scaling factors for the reconstruction and DMSAD losses
        logger.info(' Initializing the loss scale factors.')
        self.initialize_loss_scale_weight(train_loader, net, loss_fn_rec,
                                          loss_fn_ad)
        logger.info(
            f' reconstruction loss scale factor initialized to {self.scale_rec:.6f}'
        )
        logger.info(
            f' MSAD embedding loss scale factor initialized to {self.scale_ad:.6f}'
        )

        # define the optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define the learning rate scheduler : reduce the learning rate by 90% at each milestone
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info(' Start Training Jointly the DMSAD and the Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch):
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()
            n_k = torch.zeros(self.c.shape[0], device=self.device)

            for b, data in enumerate(train_loader):
                # get batch data
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float().requires_grad_(True)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # mask input
                input = input * mask

                # Update the network by backpropagation using the two losses.
                optimizer.zero_grad()
                rec, embed = net(input)
                # reconstruction loss only on non-abnormal samples (zero loss, hence no gradient, for known abnormal)
                rec = torch.where(
                    semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1,
                    rec, input)
                loss_rec = self.scale_rec * self.criterion_weight[
                    0] * loss_fn_rec(rec, input, mask)
                # DMSAD loss
                loss_ad = self.scale_ad * self.criterion_weight[
                    1] * loss_fn_ad(embed, self.c, semi_label)
                # total loss
                loss = loss_rec + loss_ad
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # get the closest sphere and count the number of normal samples per sphere
                idx = torch.argmin(torch.norm(self.c.unsqueeze(0) -
                                              embed.unsqueeze(1),
                                              p=2,
                                              dim=2),
                                   dim=1)
                for i in idx[semi_label != -1]:
                    n_k[i] += 1

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            # remove centers with less than gamma fraction of largest hypersphere number of sample
            self.c = self.c[n_k >= self.gamma * torch.max(n_k)]

            # intermediate validation of the model if required
            valid_auc = ''
            if valid_dataset:
                auc_rec, auc_ad = self.evaluate(net,
                                                valid_dataset,
                                                mode='valid',
                                                final=False)
                valid_auc = f' Rec AUC {auc_rec:.3%} | MSAD AUC {auc_ad:.3%} |'

            # print epoch statistics
            logger.info(
                f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train Time {time.time() - epoch_start_time:.3f} [s] '
                f'| Train Loss {epoch_loss / n_batch:.6f} '
                f'| N sphere {self.c.shape[0]:03} |' + valid_auc)
            # store loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update learning rate if milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(
                    f'---- LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}'
                )

            # re-initialized loss scale factors after few epochs when the centers are more or less defined
            if epoch + 1 == self.reset_scaling_epoch:
                with torch.no_grad():
                    # Compute the scaling factors for the reconstruction and DMSAD losses
                    logger.info('---- Reinitializing the loss scale factors.')
                    self.initialize_loss_scale_weight(train_loader, net,
                                                      loss_fn_rec, loss_fn_ad)
                    logger.info(
                        f'---- reconstruction loss scale factor reinitialized to {self.scale_rec:.6f}'
                    )
                    logger.info(
                        f'---- MSAD embedding loss scale factor reinitialized to {self.scale_ad:.6f}'
                    )

        # Set the radius of each sphere as 1-gamma quantile of normal samples distances
        logger.info(
            f'---- Setting the hyperspheres radii as the {1-self.gamma:.1%} quantiles of normal sample distances.'
        )
        self.set_radius(train_loader, net)
        logger.info(f'---- {self.R.shape[0]} radii successfully defined.')

        # End Training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'---- Finished jointly training the DMSAD and the Autoencoder in {self.train_time:.3f} [s].'
        )

        return net
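
Note: initialize_loss_scale_weight is called at the start of training (and again at reset_scaling_epoch) to make the reconstruction and MSAD losses comparable, but its body is not shown on this page. Below is a plausible sketch that sets each scale factor to the inverse of the initial mean loss, written as a free function that returns the two factors; the exact recipe and signature used by the author are assumptions.

import torch

def initialize_loss_scale_weight(loader, net, loss_fn_rec, loss_fn_ad, c, device='cpu'):
    """Sketch: scale_rec, scale_ad = 1 / (mean initial loss), so both loss terms start near 1.
    The real method presumably stores the factors on self as scale_rec / scale_ad (assumption)."""
    sum_rec, sum_ad, n = 0.0, 0.0, 0
    net.eval()
    with torch.no_grad():
        for input, _, mask, semi_label, _ in loader:   # loader yields (image, label, mask, semi_label, idx)
            mask = mask.to(device)
            input = input.to(device).float() * mask
            rec, embed = net(input)
            sum_rec += loss_fn_rec(rec, input, mask).item()
            sum_ad += loss_fn_ad(embed, c, semi_label.to(device)).item()
            n += 1
    return 1.0 / max(sum_rec / n, 1e-12), 1.0 / max(sum_ad / n, 1e-12)
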
    def train(self, dataset, net, valid_dataset=None):
        """
        Train the ARAE network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels.
            |---- net (nn.Module) The ARAE to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The trained ARAE.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)
        # put net to device
        net = net.to(self.device)

        # define the criterions
        criterion_rec = MaskedMSELoss()
        criterion_lat = nn.MSELoss()

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the ARAE.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float().requires_grad_(True)
                semi_label = semi_label.to(self.device)
                mask = mask.to(self.device)

                # mask input
                input = input * mask

                if self.use_PGD:
                    adv_input = self.adversarial_search(input, net)
                else:
                    adv_input = self.FGSM(input, net)

                # pass the adversarial and normal samples through the network
                net.encoding_only = True
                _, lat = net(input)
                net.encoding_only = False
                rec_adv, lat_adv = net(adv_input)
                # compute the loss
                loss_rec = criterion_rec(adv_input, rec_adv, mask)
                loss_lat = criterion_lat(lat, lat_adv)
                loss = loss_rec + self.gamma * loss_lat

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

            valid_auc = ''
            if valid_dataset:
                auc = self.validate(valid_dataset, net, final=False)
                net.train()
                valid_auc = f' Rec AUC {auc:.3%} |'

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train Time: {epoch_train_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} |' +
                        valid_auc)

            # append the epoch loss to results list
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(
                    f'>>> LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}'
                )

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(f'>>> Training Time of ARAE: {self.train_time:.3f} [s]')
        logger.info('>>> Finished ARAE Training.\n')

        return net
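
Note: ARAE training perturbs every batch with an adversarial search (self.FGSM, or a PGD variant when use_PGD is set), neither of which is shown here. A minimal FGSM-style sketch against the masked reconstruction loss, under the assumption that the attack maximizes the autoencoder's reconstruction error within a small step; the name, step size and clamping range below are illustrative only.

import torch

def fgsm_on_reconstruction(input, net, criterion, mask, epsilon=0.05):
    """Sketch of a one-step FGSM attack that increases the masked reconstruction error
    (the exact objective of the real self.FGSM is an assumption)."""
    adv = input.clone().detach().requires_grad_(True)
    rec, _ = net(adv)                    # net is assumed to return (reconstruction, embedding)
    loss = criterion(rec, adv, mask)     # masked reconstruction loss, as in the training loop above
    loss.backward()
    # single signed-gradient ascent step, clamped back to a valid intensity range
    return torch.clamp(adv + epsilon * adv.grad.sign(), 0.0, 1.0).detach()
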
Example #11
    def evaluate(self, net, dataset, mode='test', final=False):
        """
        Evaluate the model with the given dataset.
        ----------
        INPUT
            |---- net (nn.Module) The DSAD to validate. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is validated. It must return an image, a mask and
            |           semi-supervised labels.
            |---- mode (str) either 'valid' or 'test'. Define the evaluation mode.
            |           In 'valid' the evaluation can return the reconstruction
            |           and SAD AUCs and compute the best threshold to maximize
            |           the F1-scores. In test mode the validation threshold is
            |           used to compute the F1-score.
            |---- final (bool) whether the call represents the final validation,
            |           in which case the validation results are saved. Only
            |           relevant if mode is 'valid'.
        OUTPUT
            |---- auc (tuple (reconstruction auc, ad auc)) the validation AUCs for
            |           both scores are returned only if final is False. Otherwise
            |           None is returned.
        """
        assert mode in ['valid','test'], f'Mode {mode} is not supported. Should be either "valid" or "test".'
        logger = logging.getLogger()

        # make test dataloader using image and mask
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                        shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss(reduction='none')
        criterion_ad = self.SADLoss(self.space_repr, self.eta, eps=self.eps)

        # Testing
        if final or mode == 'test':
            logger.info(f' Start Evaluating the jointly trained DSAD and AutoEncoder in {mode} mode.')
        epoch_loss = 0.0
        n_batch = len(loader)
        start_time = time.time()
        idx_label_score_rec, idx_label_score_ad = [], []

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                input, label, mask, semi_label, idx = data
                # put data to device
                input, label = input.to(self.device).float(), label.to(self.device)
                mask, semi_label = mask.to(self.device), semi_label.to(self.device)
                idx = idx.to(self.device)

                # mask the input
                input = input * mask

                # compute loss
                rec, embed = net(input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_ad = criterion_ad(embed, semi_label)
                # compute anomaly scores
                rec_score = torch.mean(loss_rec, dim=tuple(range(1, rec.dim()))) # mean over all dimensions per sample
                #rec_score = torch.sum(loss_rec, dim=tuple(range(1, rec.dim()))) / (torch.sum(mask, dim=tuple(range(1, rec.dim()))) + 1) # mean reconstruction MSE on the mask per batch
                if self.use_subspace:
                    ad_score = torch.sum((embed - torch.matmul(self.space_repr, embed.transpose(0,1)).transpose(0,1)) ** 2, dim=1) # score is the distance (large distances highlight anomalies)
                else:
                    ad_score = torch.sum((embed - self.space_repr) ** 2, dim=1) # score is the distance (large distances highlight anomalies)
                # compute overall loss
                mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
                loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
                loss += self.scale_em * self.criterion_weight[1] * loss_ad

                # append scores and label
                idx_label_score_rec += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            ad_score.cpu().data.numpy().tolist()))

                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b, n_batch, Name='\t\tBatch', Size=40, erase=True)

        # compute AUCs
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        auc_rec = roc_auc_score(label, rec_score)

        _, label, ad_score = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        auc_ad = roc_auc_score(label, ad_score)

        if mode == 'valid':
            if final:
                self.valid_time = time.time() - start_time
                self.valid_scores_rec = idx_label_score_rec
                self.valid_auc_rec = auc_rec
                self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(rec_score, label, metric=f1_score)
                self.valid_scores_ad = idx_label_score_ad
                self.valid_auc_ad = auc_ad
                self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(ad_score, label, metric=f1_score)

                # add info to logger
                logger.info(f'---- Validation Time: {self.valid_time:.3f} [s]')
                logger.info(f'---- Validation Loss: {epoch_loss / n_batch:.6f}')
                logger.info(f'---- Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
                logger.info(f'---- Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}')
                logger.info(f'---- Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}')
                logger.info(f'---- Validation SAD AUC: {self.valid_auc_ad:.3%}')
                logger.info(f'---- Best Threshold for the SAD score maximizing F1-score: {self.scores_threhold_ad:.3f}')
                logger.info(f'---- Best F1-score on SAD score: {self.valid_f1_ad:.3%}')
                logger.info('---- Finished validating the Joint DSAD and AutoEncoder.\n')
            else:
                return auc_rec, auc_ad

        elif mode == 'test':
            # save results
            self.test_time = time.time() - start_time
            self.test_scores_rec = idx_label_score_rec
            self.test_auc_rec = auc_rec
            self.test_scores_ad = idx_label_score_ad
            self.test_auc_ad = auc_ad

            # print infos
            logger.info(f'---- Test Time: {self.test_time:.3f} [s]')
            logger.info(f'---- Test Loss: {epoch_loss / n_batch:.6f}')
            logger.info(f'---- Test reconstruction AUC: {self.test_auc_rec:.3%}')
            if self.scores_threhold_rec is not None:
                self.test_f1_rec = f1_score(label, np.where(rec_score > self.scores_threhold_rec, 1, 0))
                logger.info(f'---- Best F1-score on reconstruction score: {self.test_f1_rec:.3%}')
            logger.info(f'---- Test SAD AUC: {self.test_auc_ad:.3%}')
            if self.scores_threhold_ad is not None:
                self.test_f1_ad = f1_score(label, np.where(ad_score > self.scores_threhold_ad, 1, 0))
                logger.info(f'---- Best F1-score on SAD score: {self.test_f1_ad:.3%}')
            logger.info('---- Finished testing the Joint DSAD and AutoEncoder.\n')
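
get_best_threshold is called above but not shown in this listing. A plausible minimal sketch, assuming it simply sweeps candidate thresholds over the observed score range and returns the (threshold, metric value) pair that maximizes the given metric:

import numpy as np
from sklearn.metrics import f1_score

def get_best_threshold_sketch(scores, labels, metric=f1_score, n_steps=200):
    """Sweep thresholds over the score range and keep the one maximizing `metric`."""
    scores, labels = np.asarray(scores), np.asarray(labels)
    best_thr, best_val = None, -np.inf
    for thr in np.linspace(scores.min(), scores.max(), n_steps):
        val = metric(labels, np.where(scores > thr, 1, 0))
        if val > best_val:
            best_thr, best_val = thr, val
    return best_thr, best_val
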
    def test(self, dataset, net):
        """
        Test the ARAE network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is tested. It must return an image, a mask and
            |           semi-supervised labels.
            |---- net (nn.Module) The ARAE to test. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None
        """
        logger = logging.getLogger()

        # make test dataloader using image and mask
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                        shuffle=True, num_workers=self.n_jobs_dataloader)
        # put net to device
        net = net.to(self.device)

        # loss function
        criterion = MaskedMSELoss(reduction='none')

        # Testing
        logger.info('>>> Start Testing of the ARAE.')
        epoch_loss = 0.0
        n_batch = 0
        n_batch_tot = test_loader.__len__()
        start_time = time.time()
        idx_label_score = []

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(test_loader):
                input, label, mask, _, idx = data
                # put data to device
                input = input.to(self.device).float()
                label = label.to(self.device).float()
                mask = mask.to(self.device)
                idx = idx.to(self.device)
                # mask input
                input = input * mask

                rec, _ = net(input)
                loss = criterion(rec, input, mask)

                ad_score = torch.mean(loss, dim=tuple(range(1, rec.dim())))  # mean loss per sample

                idx_label_score += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        ad_score.cpu().data.numpy().tolist()))
                # compute the mean reconstruction loss
                loss = torch.sum(loss) / torch.sum(mask)
                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        self.test_time = time.time() - start_time
        self.test_scores = idx_label_score
        _, label, ad_score = zip(*idx_label_score)
        label, ad_score = np.array(label), np.array(ad_score)
        self.test_auc = roc_auc_score(label, ad_score)
        self.test_f1 = f1_score(
            label, np.where(ad_score > self.scores_threshold, 1, 0))

        # add info to logger
        logger.info(f'>>> Testing Time: {self.test_time:.3f} [s]')
        logger.info(f'>>> Test Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'>>> Test AUC: {self.test_auc:.3%}')
        logger.info(f'>>> Test F1-score: {self.test_f1:.3%}')
        logger.info('>>> Finished testing the ARAE.\n')
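
MaskedMSELoss is defined elsewhere in the repository. A minimal sketch of the behaviour the trainers above rely on, assuming the mask marks the pixels on which the loss is computed and that reduction='none' returns the per-pixel losses:

import torch
import torch.nn as nn

class MaskedMSELossSketch(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, rec, target, mask):
        err = mask * (rec - target) ** 2             # squared error restricted to the mask
        if self.reduction == 'none':
            return err                               # per-pixel losses, used for anomaly scores
        return torch.sum(err) / torch.sum(mask)      # scalar loss averaged over the masked pixels
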
    def train(self, dataset, ae_net):
        """
        Train the autoencoder network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image and a mask
            |           of where the loss is to be computed.
            |---- ae_net (nn.Module) The autoencoder to train.
        OUTPUT
            |---- ae_net (nn.Module) The trained autoencoder.
        """
        logger = logging.getLogger()

        # make train dataloader using image and mask
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                        shuffle=True, num_workers=self.n_jobs_dataloader)

        # Masked MSE loss (scalar reduction); the mask restricts where the loss is computed
        criterion = MaskedMSELoss()

        # set to device
        ae_net = ae_net.to(self.device)
        criterion = criterion.to(self.device)

        # set optimizer
        optimizer = optim.Adam(ae_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # set the learning rate scheduler (multiple phase learning)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Training
        logger.info('>>> Start Training the AutoEncoder.')
        start_time = time.time()
        epoch_loss_list = []
        # set the network in train mode
        ae_net.train()

        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                input, _, mask, _, _ = data
                # put inputs to device
                input, mask = input.to(self.device).float(), mask.to(self.device)

                # zero the network gradients
                optimizer.zero_grad()

                # Update network parameters by backpropagation, considering only the loss on the mask
                rec = ae_net(input)
                loss = criterion(rec, input, mask)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, train_loader.__len__(), Name='\t\tBatch', Size=20)

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | Train Time: {epoch_train_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} |')

            epoch_loss_list += [[epoch+1, epoch_loss/n_batch]]

            # apply the scheduler step
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(f'>>> LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}')

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(f'>>> Training of AutoEncoder Time: {self.train_time:.3f} [s]')
        logger.info('>>> Finished AutoEncoder Training.\n')

        return ae_net
    def test(self, dataset, ae_net):
        """
        Test the autoencoder network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is tested. It must return an image and a mask
            |           of where the loss is to be computed.
            |---- ae_net (nn.Module) The autoencoder network to test.
        OUTPUT
            |---- None
        """
        logger = logging.getLogger()

        # make test dataloader using image and mask
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                        shuffle=True, num_workers=self.n_jobs_dataloader)

        # MSE loss without reduction --> MSE loss for each output pixel
        criterion = MaskedMSELoss(reduction='none')

        # set to device
        ae_net = ae_net.to(self.device)
        criterion = criterion.to(self.device)

        # Testing
        logger.info('>>> Start Testing the AutoEncoder.')
        epoch_loss = 0.0
        n_batch = 0
        start_time = time.time()
        idx_label_score = []
        # put network in evaluation mode
        ae_net.eval()

        with torch.no_grad():
            for b, data in enumerate(test_loader):
                input, label, mask, _, idx = data
                # put inputs to device
                input, label = input.to(self.device).float(), label.to(self.device)
                mask, idx = mask.to(self.device), idx.to(self.device)

                rec = ae_net(input)
                rec_loss = criterion(rec, input, mask)
                score = torch.mean(rec_loss, dim=tuple(range(1, rec.dim()))) # mean over all dimensions per sample

                # append scores and label
                idx_label_score += list(zip(idx.cpu().data.numpy().tolist(),
                                            label.cpu().data.numpy().tolist(),
                                            score.cpu().data.numpy().tolist()))
                # overall batch loss
                loss = torch.sum(rec_loss) / torch.sum(mask)
                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, test_loader.__len__(), Name='\t\tBatch', Size=20)

        self.test_time = time.time() - start_time
        self.test_scores = idx_label_score

        # Compute AUC: if the AE is good, a high reconstruction loss highlights the presence of an anomaly in the image
        _, label, score = zip(*idx_label_score)
        label, score = np.array(label), np.array(score)
        self.test_auc = roc_auc_score(label, score)
        self.test_f1 = f1_score(label, np.where(score > self.scores_threhold, 1, 0))

        # add info to logger
        logger.info(f'>>> Test Time: {self.test_time:.3f} [s]')
        logger.info(f'>>> Test Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'>>> Test AUC: {self.test_auc:.3%}')
        logger.info(f'>>> Test F1-score: {self.test_f1:.3%}')
        logger.info('>>> Finished Testing the AutoEncoder.\n')
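
The anomaly score used above is the per-pixel reconstruction loss averaged over every dimension except the batch one, giving one score per image; the AUC then ranks these scores against the binary labels. A small illustration on dummy tensors:

import numpy as np
import torch
from sklearn.metrics import roc_auc_score

rec_loss = torch.rand(8, 1, 64, 64)                                  # dummy per-pixel losses (B, C, H, W)
score = torch.mean(rec_loss, dim=tuple(range(1, rec_loss.dim())))    # one score per sample, shape (8,)
labels = np.array([0, 0, 0, 0, 1, 1, 1, 1])                          # dummy ground-truth anomaly labels
print(score.shape, roc_auc_score(labels, score.numpy()))
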
    def train(self, net, dataset, valid_dataset=None):
        """
        Train the autoencoder network on the provided dataset.
        ----------
        INPUT
            |---- net (nn.Module) The autoencoder to train. It must return two
            |           embeddings (after the convolution and after the MLP) as
            |           well as the reconstruction.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, the label,
            |           a mask, semi-supervised label and the index.
            |---- valid_dataset (torch.utils.data.Dataset) the optional dataset
            |           on which to validate the model at each epoch.
        OUTPUT
            |---- net (nn.Module) The trained autoencoder.
        """
        logger = logging.getLogger()

        # make train dataloader using image and mask
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                        shuffle=True, num_workers=self.n_job_dataloader)
        # define loss_fn
        loss_fn = MaskedMSELoss()

        # set network on device
        net = net.to(self.device)

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # set the learning rate scheduler (multiple phase learning)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Training
        logger.info('Start Training AE.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            epoch_start_time = time.time()
            net.train()

            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                # put inputs to device
                input = input.to(self.device).float().requires_grad_(True)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # keep only input that are normal
                #input = input[semi_label != -1]
                #mask = mask[semi_label != -1]
                # mask input
                input = input * mask

                # zero the network gradients
                optimizer.zero_grad()

                # Update network parameters by backpropagation, considering only the loss on the mask
                _, _, rec = net(input)
                loss = loss_fn(rec, input, mask)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            valid_auc = ''
            if valid_dataset:
                auc = self.evaluate(net,
                                    valid_dataset,
                                    save_tSNE=False,
                                    return_auc=True,
                                    print_to_logger=False)
                valid_auc = f' Valid AUC {auc:.6f} |'

            # display epoch statistics
            logger.info(f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Time {time.time() - epoch_start_time:.3f} [s] '
                        f'| Loss {epoch_loss / n_batch:.6f} |' + valid_auc)

            # store loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update learning rate if milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                logger.info(
                    f'---- LR Scheduler : new learning rate {scheduler.get_lr()[0]:g}'
                )

        # Save results
        self.train_time = time.time() - start_time
        self.train_loss = epoch_loss_list
        logger.info(f'---- Finished Training AE in {self.train_time:.3f} [s].')

        return net
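
The train method above and the evaluate method below assume a network whose forward pass returns (h, z, rec): a convolutional embedding, an MLP embedding and the reconstruction. A toy sketch of that interface, with made-up layer sizes for 1x64x64 inputs:

import torch
import torch.nn as nn

class ToyAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(1, 8, 3, stride=2, padding=1), nn.ReLU())
        self.mlp = nn.Linear(8 * 32 * 32, 32)
        self.dec = nn.ConvTranspose2d(8, 1, 4, stride=2, padding=1)

    def forward(self, x):
        f = self.enc(x)                              # conv features (B, 8, 32, 32)
        h = torch.flatten(f, start_dim=1)            # convolutional embedding
        z = self.mlp(h)                              # MLP embedding
        rec = self.dec(f)                            # reconstruction (B, 1, 64, 64)
        return h, z, rec

h, z, rec = ToyAE()(torch.rand(2, 1, 64, 64))        # h: (2, 8192), z: (2, 32), rec: (2, 1, 64, 64)
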
    def evaluate(self,
                 net,
                 dataset,
                 print_to_logger=True,
                 return_auc=False,
                 save_tSNE=True):
        """
        Evaluate the network on the provided dataset.
        ----------
        INPUT
            |---- net (nn.Module) The autoencoder to evaluate. It must return two
            |           embeddings (after the convolution and after the MLP) as
            |           well as the reconstruction.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is evaluated. It must return an image, the label,
            |           a mask, a semi-supervised label and the index.
            |---- print_to_logger (bool) whether to print info in logger.
            |---- return_auc (bool) whether to return the computed AUC.
            |---- save_tSNE (bool) whether to save the intermediate representation
            |           as a 2D vector using tSNE.
        OUTPUT
            |---- auc (float) the reconstruction AUC, returned only if return_auc
            |           is True. Otherwise None is returned.
        """
        if print_to_logger:
            logger = logging.getLogger()
        # make the evaluation dataloader using image and mask
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             num_workers=self.n_job_dataloader)

        # put net on device
        net = net.to(self.device)

        # define loss function
        loss_fn = MaskedMSELoss(reduction='none')

        if print_to_logger:
            logger.info("Start Evaluating AE.")

        idx_label_scores = []
        n_batch = len(loader)

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                input, label, mask, semi_label, idx = data
                # put inputs to device
                input = input.to(self.device).float()
                label = label.to(self.device)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                idx = idx.to(self.device)

                # mask input
                input = input * mask

                h, z, rec = net(input)

                # compute score as the mean loss per sample
                rec_loss = loss_fn(rec, input, mask)
                score = torch.mean(rec_loss, dim=tuple(range(1, rec.dim())))

                # append scores : idx label score h z
                idx_label_scores += list(
                    zip(idx.cpu().data.numpy().tolist(),
                        label.cpu().data.numpy().tolist(),
                        score.cpu().data.numpy().tolist(),
                        h.cpu().data.numpy().tolist(),
                        z.cpu().data.numpy().tolist()))

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tEvaluation Batch',
                                     Size=40,
                                     erase=True)

        if save_tSNE:
            if print_to_logger:
                logger.info("Computing the t-SNE representation.")
            # Apply t-SNE transform on embeddings
            index, label, scores, h, z = zip(*idx_label_scores)
            h, z = np.array(h), np.array(z)
            h = TSNE(n_components=2).fit_transform(h)
            z = TSNE(n_components=2).fit_transform(z)
            self.eval_repr = list(
                zip(index, label, scores, h.tolist(), z.tolist()))

            if print_to_logger:
                logger.info("Succesfully computed the t-SNE representation ")

        if return_auc:
            _, label, scores, _, _ = zip(*idx_label_scores)
            auc = roc_auc_score(np.array(label), np.array(scores))
            return auc
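
The t-SNE step above reduces the stored embeddings to two dimensions purely so they can be plotted later. A small illustration on dummy embeddings:

import numpy as np
from sklearn.manifold import TSNE

h = np.random.randn(100, 128)                        # dummy 128-d embeddings for 100 samples
h_2d = TSNE(n_components=2).fit_transform(h)         # (100, 2) coordinates for plotting
print(h_2d.shape)
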