def train(self, dataset, n_cluster=500):
    """
    Train the nearest-neighbour scorer of the ablation study. The
    representation of every normal training sample (semi-supervised label
    different from -1) is computed with the network and kept as a reference
    point (stored in self.centers). Anomaly scores are later obtained from the
    distance to those points.
    ----------
    INPUT
        |---- dataset (torch.utils.data.Dataset) the dataset on which the
        |           network is evaluated. It must return a tuple (input,
        |           label, mask, semi-supervised label, idx).
        |---- n_cluster (int) the number of centers the training data was
        |           meant to be summarised into with KMeans. Currently
        |           UNUSED: the KMeans summarisation is disabled and all the
        |           normal sample representations are kept.
    OUTPUT
        |---- None
    """
    logger = logging.getLogger()

    # make dataloader
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=self.batch_size,
                                         shuffle=True,
                                         num_workers=self.n_job_dataloader)

    # get representation of normal samples with net
    logger.info(
        'Getting normal train sample representation for the Ablation Study.'
    )
    embeddings = []
    self.net.eval()
    with torch.no_grad():
        for b, data in enumerate(loader):
            # get data
            input, _, mask, semi_label, _ = data
            input = input.to(self.device).float()
            mask = mask.to(self.device)
            semi_label = semi_label.to(self.device)
            # mask input and keep only normal samples
            input = (input * mask)[semi_label != -1]
            # first element returned by the net is the transferred representation
            embeddings.append(self.net(input)[0])

            if self.print_batch_progress:
                print_progessbar(b,
                                 len(loader),
                                 Name='\t\tBatch',
                                 Size=40,
                                 erase=True)

    # Every normal representation is used as a reference point (no KMeans
    # summarisation). The embeddings are concatenated directly as a tensor,
    # avoiding a device -> numpy -> device round-trip.
    self.centers = torch.cat(embeddings, dim=0).to(self.device)
    logger.info(f'{self.centers.shape[0]} points successfully generated')
    def evaluate(self, dataset, last=True):
        """
        Evaluate the network on the provided dataset and compute its accuracy.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset to evaluate.
            |           It must return a tuple (input, label, idx).
            |---- last (bool) whether this is the final evaluation. If True,
            |           the accuracy and the (index, label, prediction) lists
            |           are stored in self.test_acc / self.test_pred and
            |           logged. If False, they are returned instead.
        OUTPUT
            |---- (acc, pred) (float, list) when last is False, None otherwise.
        """
        logger = logging.getLogger()
        if last:
            logger.info('>>> Start Evaluating the LeNet5.')

        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                             num_workers=self.num_workers)

        N = len(loader)
        # Evaluate in eval mode (dropout disabled, batch-norm uses its running
        # statistics) and restore the previous mode afterwards, since this
        # method is also called in the middle of training.
        was_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                pred, label, index = [], [], []
                for b, (input_data, input_label, idx) in enumerate(loader):
                    input_data = input_data.float().to(self.device)
                    # classify sample
                    pred += self.net(input_data).argmax(dim=1).tolist()
                    label += input_label.tolist()
                    index += idx.tolist()

                    print_progessbar(b, N, Name='Evaluation Batch', Size=40, erase=True)
        finally:
            self.net.train(was_training)

        # compute accuracy
        acc = sklearn.metrics.accuracy_score(label, pred)

        if last:
            self.test_acc, self.test_pred = acc, (index, label, pred)
            logger.info(f'>>> Test accuracy {self.test_acc:.3%} \n')
        else:
            return acc, pred
    def initialize_projection_matrix(self, loader, net, n_sample_max=10000,
                                     reg=1e-3):
        """
        Initialize the subspace projection matrix from at most n_sample_max
        normal samples (semi-supervised label different from -1). The
        resulting matrix projects an embedded point onto the estimated
        subspace of normal samples: P = S (S'S + reg I)^-1 S', where the
        columns of S are the normal sample embeddings.
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the network. Its forward pass must return
            |           both the reconstruction and the embedding of the input.
            |---- n_sample_max (int) the maximum number of normal samples used
            |           to build the projection (default 10000).
            |---- reg (float) the regularization added to the diagonal of S'S
            |           before inversion (default 1e-3).
        OUTPUT
            |---- P (torch.Tensor) the (M x M) projection matrix, with M the
            |           embedding dimension.
        """
        N = n_sample_max

        # Get S : matrix of sample embeding (M x N) with M the embeding dimension and N the number of samples
        S = []
        n_sample = 0
        net.eval()
        with torch.no_grad():
            for data in loader:
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                # mask input
                input = input * mask
                _, embed = net(input)
                # keep only embedding points of normal samples, without
                # exceeding N samples in total
                embed = embed[semi_label != -1, :][:N - n_sample]
                S.append(embed)
                n_sample += embed.shape[0]

                if self.print_batch_progress:
                    print_progessbar(n_sample, N, Name='\t\tSamples', Size=20)

                if n_sample >= N:
                    break

        S = torch.cat(S, dim=0).transpose(0, 1).to(self.device)

        # compute P = S(S'S + reg*I)^-1 S'
        inv = torch.inverse(
            torch.matmul(S.transpose(0, 1), S) +
            reg * torch.eye(S.shape[1], device=self.device))
        P = torch.matmul(S, torch.matmul(inv, S.transpose(0, 1)))

        return P
    def initialize_loss_scale_weight(self, loader, net, loss_fn_rec,
                                     loss_fn_ad):
        """
        Perform one forward pass over the loader to compute the reconstruction
        and embedding loss scaling factors so that the two losses have a
        magnitude of 1 (the inverse of their mean batch value).
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DMSAD network. The output must be a vector
            |           embedding of the input. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- loss_fn_rec (nn.Module) the reconstruction loss criterion,
            |           called as loss_fn_rec(rec, input, mask).
            |---- loss_fn_ad (nn.Module) the DMSAD embedding loss criterion,
            |           called as loss_fn_ad(embed, self.c, semi_label).
        OUTPUT
            |---- None (factors stored in self.scale_rec and self.scale_ad)
        """
        sumloss_rec, sumloss_ad = 0.0, 0.0
        n_batch = len(loader)

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                # mask input
                input = input * mask
                # Forward
                rec, embed = net(input)
                # Reconstruction loss only on non-abnormal samples: known
                # abnormal samples (semi_label == -1) are given their input as
                # "reconstruction" so that their loss is 0.
                rec = torch.where(
                    semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1,
                    rec, input)
                loss_rec = loss_fn_rec(rec, input, mask)
                sumloss_rec += loss_rec.item()
                # DMSAD loss
                loss_ad = loss_fn_ad(embed, self.c, semi_label)
                sumloss_ad += loss_ad.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tBatch',
                                     Size=40,
                                     erase=True)

            # get the scale factors
            self.scale_rec = 1 / (sumloss_rec / n_batch)
            self.scale_ad = 1 / (sumloss_ad / n_batch)
Exemplo n.º 5
0
    def set_radius(self, loader, net):
        """
        Compute the radius of each hypersphere as the (1 - gamma) quantile of
        the distances between the normal samples and their closest center.
        The anomaly score is then ||net(x) - c_j||^2 - R_j^2, negative if the
        sample is inside its sphere and positive if it is outside.
        ----------
        INPUT
            |---- loader (torch.utils.data.Dataloader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DMSAD network. The output must be a vector
            |           embedding of the input. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None (one radius per center of self.c is stored in self.R;
            |           a center to which no normal sample is closest keeps a
            |           radius of 0)
        """
        n_sphere = self.c.shape[0]
        dist_chunks, idx_chunks = [], []
        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data
                input, _, mask, semi_label, _ = data
                input = input.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                # mask input and keep only normal samples
                input = (input * mask)[semi_label != -1]
                # get embdeding of batch
                _, embed = net(input)

                # distance to, and index of, the closest sphere of each sample
                dist, idx = torch.min(torch.norm(self.c.unsqueeze(0) -
                                                 embed.unsqueeze(1),
                                                 p=2,
                                                 dim=2),
                                      dim=1)
                dist_chunks.append(dist)
                idx_chunks.append(idx)

                if self.print_batch_progress:
                    print_progessbar(b,
                                     len(loader),
                                     Name='\t\tBatch',
                                     Size=40,
                                     erase=True)

            # compute the radius as 1-gamma quantile of the normal distances of each spheres
            self.R = torch.zeros(n_sphere, device=self.device)
            if not dist_chunks:
                return
            all_dist = torch.cat(dist_chunks, dim=0)
            all_idx = torch.cat(idx_chunks, dim=0)
            for i in range(n_sphere):
                dist = all_dist[all_idx == i]
                n = dist.numel()
                if n == 0:
                    # no normal sample is closest to this center: keep R = 0
                    continue
                # kthvalue needs a 1-based rank in [1, n]; int() can give 0
                # for few samples or a gamma close to 1
                k = min(max(int((1 - self.gamma) * n), 1), n)
                self.R[i] = torch.kthvalue(dist, k=k)[0]
    def train(self, train_dataset, test_dataset):
        """
        Train the LeNet5 on the provided dataset, evaluating the accuracy on
        test_dataset after every epoch.
        ----------
        INPUT
            |---- train_dataset (torch.utils.data.Dataset) the training data,
            |           returning a tuple (input, label, idx).
            |---- test_dataset (torch.utils.data.Dataset) the data used for
            |           the per-epoch test accuracy, same format.
        OUTPUT
            |---- None (per-epoch [epoch, loss, test accuracy] rows are
            |           appended to self.epoch_loss_list; self.train_time and
            |           self.train_acc are set at the end)
        """
        logger = logging.getLogger()

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size,
                                                   num_workers=self.num_workers)
        n_batch = len(train_loader)

        logger.info(f'>>> Start Training the LeNet5 with seed {self.seed}.')
        start_time = time.time()
        for epoch in range(self.n_epoch):
            epoch_start_time = time.time()
            epoch_loss = 0.0
            # make sure dropout / batch-norm are in training mode for this epoch
            self.net.train()
            # minibatch iteration
            for b, (train_input, train_label, _) in enumerate(train_loader):
                train_input = train_input.float().to(self.device)
                train_label = train_label.to(self.device)
                # update weight by backpropagation
                pred = self.net(train_input)
                loss = self.loss_fn(pred, train_label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

                print_progessbar(b, n_batch, Name='Train Batch', Size=40, erase=True)

            # evaluate on test set
            test_acc, _ = self.evaluate(test_dataset, last=False)
            # store epoch stat
            self.epoch_loss_list.append([epoch+1, epoch_loss / n_batch, test_acc])
            # print summary statistics
            logger.info(f'>>> | Epoch {epoch+1:03}/{self.n_epoch:03} '
                        f'| Loss {epoch_loss / n_batch:.7f} '
                        f'| Test Accuracy {test_acc:.3%} '
                        f'| Time {time.time() - epoch_start_time:.2f} [s] |')

            # update learning rate
            self.scheduler.step()

        # Get results
        self.train_time = time.time() - start_time
        self.train_acc, _ = self.evaluate(train_dataset, last=False)

        logger.info(f'>>> Finished training of LeNet5')
        logger.info(f'>>> Train time {self.train_time:.0f} [s]')
        logger.info(f'>>> Train accuracy {self.train_acc:.3%}\n')
    def initialize_hypersphere_center(self, loader, net, eps=0.1):
        """
        Initialize the hypersphere center as the mean embedding of the
        (masked) samples over one pass of the loader.
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DeepSVDD network. The output must be a vector
            |           embedding of the input. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- eps (float) the epsilon representing the minimum value of the
            |           component of the center.
        OUTPUT
            |---- c (torch.Tensor) the initialized center.
        RAISES
            |---- ValueError if the loader yields no sample.
        """
        n_sample = 0
        c = None

        # eval mode BEFORE any forward pass so that batch-norm running
        # statistics are not modified while computing the center
        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                input, _, mask, _, _ = data
                input = input.to(self.device).float()
                mask = mask.to(self.device)
                # mask input
                input = input * mask

                _, embed = net(input)
                if c is None:
                    # embedding dimension taken from the first batch, avoiding
                    # an extra forward pass
                    c = torch.zeros(embed.shape[1], device=self.device)
                n_sample += embed.shape[0]
                c += torch.sum(embed, dim=0)

                if self.print_batch_progress:
                    print_progessbar(b,
                                     len(loader),
                                     Name='\t\tBatch',
                                     Size=20)

        if n_sample == 0:
            raise ValueError('Cannot initialize the hypersphere center from an empty loader.')

        # take the mean of accumulated c
        c /= n_sample
        # check if c_i are epsilon too close to zero to avoid them to be trivialy matched to zero
        c[(torch.abs(c) < eps) & (c < 0)] = -eps
        c[(torch.abs(c) < eps) & (c > 0)] = eps

        return c
Exemplo n.º 8
0
def test_AE(ae_net,
            dataset,
            batch_size=16,
            n_jobs_dataloader=4,
            device='cuda'):
    """
    Score every sample of dataset with an autoencoder, the score being the
    mean of the masked per-pixel reconstruction MSE of the sample.
    ----------
    INPUT
        |---- ae_net (nn.Module) the autoencoder. Its forward pass is
        |           expected to return the reconstruction tensor only.
        |---- dataset (torch.utils.data.Dataset) dataset returning a tuple
        |           (input, label, mask, semi_label, idx).
        |---- batch_size (int) the batch size of the loader.
        |---- n_jobs_dataloader (int) the number of loader workers.
        |---- device (str) the device on which to run the network.
    OUTPUT
        |---- test_time (float) the scoring time in seconds.
        |---- scores (list of tuple) one (idx, label, score) per sample.
    """
    # make test dataloader using image and mask
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=n_jobs_dataloader)

    # MSE loss without reduction --> MSE loss for each output pixels
    criterion = MaskedMSELoss(reduction='none')

    # set to device
    ae_net = ae_net.to(device)
    criterion = criterion.to(device)

    start_time = time.time()
    idx_label_score = []
    # put network in evaluation mode
    ae_net.eval()

    with torch.no_grad():
        for b, data in enumerate(loader):
            input, label, mask, _, idx = data
            # only the network inputs need to be on the device
            input = input.to(device).float()
            mask = mask.to(device)

            rec = ae_net(input)
            rec_loss = criterion(rec, input, mask)
            # mean over all non-batch dimensions -> one score per sample
            score = torch.mean(rec_loss, dim=tuple(range(1, rec.dim())))

            # append scores and label
            idx_label_score += list(
                zip(idx.tolist(), label.tolist(), score.cpu().tolist()))

            print_progessbar(b, len(loader), Name='\t\tBatch', Size=20)

    test_time = time.time() - start_time
    return test_time, idx_label_score
Exemplo n.º 9
0
    def initialize_loss_scale_weight(self, loader, net, criterion_rec, criterion_ad):
        """
        Run one gradient-free pass over the loader and set the scaling factors
        of the reconstruction and MSVDD embedding losses to the inverse of
        their mean batch value, so that both losses start around 1.
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DMSVDD network. The output must be a vector
            |           embedding of the input. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- criterion_rec (nn.Module) the reconstruction loss criterion,
            |           called as criterion_rec(rec, input, mask).
            |---- criterion_ad (nn.Module) the MSVDD embedding loss criterion,
            |           called as criterion_ad(embed, self.c, self.R).
        OUTPUT
            |---- None (factors stored in self.scale_rec and self.scale_em)
        """
        logger = logging.getLogger()
        logger.info('>>> Initializing the loss scale factors.')
        total_rec, total_ad, batch_count = 0.0, 0.0, 0
        net.eval()
        with torch.no_grad():
            for step, batch in enumerate(loader):
                inputs, _, masks, semi_labels, _ = batch
                inputs = inputs.to(self.device).float()
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                masked_inputs = inputs * masks

                reconstruction, embedding = net(masked_inputs)
                # known abnormal samples (semi_label == -1) are compared with
                # themselves, so they add nothing to the reconstruction loss
                keep = semi_labels.view(-1, 1, 1, 1).expand(*masked_inputs.shape) != -1
                reconstruction = torch.where(keep, reconstruction, masked_inputs)

                total_rec += criterion_rec(reconstruction, masked_inputs, masks).item()
                total_ad += criterion_ad(embedding, self.c, self.R).item()
                batch_count += 1

                if self.print_batch_progress:
                    print_progessbar(step, len(loader), Name='\t\tBatch', Size=20)

            # scale so that each loss is about 1 at the first epoch
            self.scale_rec = 1 / (total_rec / batch_count)
            self.scale_em = 1 / (total_ad / batch_count)
            logger.info(f'>>> reconstruction loss scale factor initialized to {self.scale_rec:.6f}')
            logger.info(f'>>> MSVDD embdeding loss scale factor initialized to {self.scale_em:.6f}')
Exemplo n.º 10
0
    def initialize_hypersphere_center(self, loader, net, eps=0.1):
        """
        Initialize the hypersphere center as the mean embedding of the masked
        normal samples (semi-supervised label different from -1) over one
        pass of the loader.
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DeepSAD network. Its forward pass must
            |           return two elements, the second being the embedding.
            |---- eps (float) coordinates of the center whose absolute value is
            |           below eps (and non-zero) are pushed to +/- eps.
        OUTPUT
            |---- c (torch.Tensor) the initialized center.
        """
        net.eval()
        total = 0
        with torch.no_grad():
            # forward pass on the (unmasked) first batch to read the embedding dimension
            first_input = next(iter(loader))[0].float()
            dim = net(first_input.to(self.device))[1].shape[1]
            center = torch.zeros(dim, device=self.device)

            for step, batch in enumerate(loader):
                inputs, _, masks, semi_labels, _ = batch
                inputs = inputs.to(self.device).float()
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                normal_inputs = (inputs * masks)[semi_labels != -1]

                _, embeddings = net(normal_inputs)
                total += embeddings.shape[0]
                center += torch.sum(embeddings, dim=0)

                if self.print_batch_progress:
                    print_progessbar(step,
                                     len(loader),
                                     Name='\t\t Center Initialization Batch',
                                     Size=40,
                                     erase=True)

        center /= total
        # push near-zero (but non-zero) coordinates away from zero so that the
        # center is not trivially matched by a zero embedding
        tiny = torch.abs(center) < eps
        center[tiny & (center < 0)] = -eps
        center[tiny & (center > 0)] = eps
        return center
Exemplo n.º 11
0
    def initialize_centers(self, loader, net, eps=0.1):
        """
        Initialize the multiple centers by clustering, with KMeans into
        self.n_sphere_init clusters, the embeddings of all the masked normal
        samples (semi-supervised label different from -1).
        ----------
        INPUT
            |---- loader (torch.utils.data.Dataloader) the loader of the data.
            |           Batches must unpack as (input, label, mask,
            |           semi_label, idx).
            |---- net (nn.Module) the DMSAD network. The output must be a vector
            |           embedding of the input. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- eps (float) minimal absolute value for non-zero center
            |           coordinates, to avoid centers too close to zero.
        OUTPUT
            |---- None (centers stored in self.c)
        """
        chunks = []
        net.eval()
        with torch.no_grad():
            for step, batch in enumerate(loader):
                inputs, _, masks, semi_labels, _ = batch
                inputs = inputs.to(self.device).float()
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                normal_inputs = (inputs * masks)[semi_labels != -1]
                _, embedding = net(normal_inputs)
                chunks.append(embedding)

                if self.print_batch_progress:
                    print_progessbar(step,
                                     len(loader),
                                     Name='\t\tBatch',
                                     Size=40,
                                     erase=True)

            features = torch.cat(chunks, dim=0).cpu().numpy()

        # KMeans presumably is sklearn.cluster.KMeans (fit / cluster_centers_ API)
        kmeans = KMeans(n_clusters=self.n_sphere_init).fit(features)
        self.c = torch.tensor(kmeans.cluster_centers_).to(self.device)

        # push near-zero (but non-zero) coordinates to +/- eps
        tiny = torch.abs(self.c) < eps
        self.c[tiny & (self.c < 0)] = -eps
        self.c[tiny & (self.c > 0)] = eps
    def stabilize_BN(self, dataset, rep=5):
        """
        Run `rep` gradient-free passes over the dataset through the generator
        and the discriminator in train mode, so that their batch-norm running
        statistics (mean and std) are updated. Both networks are left in
        train mode.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) dataset returning a pair (image, mask) It must return an image tensor
            |               of dimension [C, H, W], a mask tensor of dimension [1, H, W].
            |---- rep (int) number of time to go over dataset.
        OUTPUT
            |---- None
        """
        with torch.no_grad():
            self.generator.train()
            self.discriminator.train()
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
                # np.random.seed() without argument re-seeds numpy in each worker
                worker_init_fn=lambda _: np.random.seed())
            n_batch = len(loader)

            for iteration in range(rep):
                for step, (im, mask) in enumerate(loader):
                    im = im.to(self.device).float()
                    mask = mask.to(self.device).float()

                    # outputs are discarded: only the batch-norm statistics
                    # updated by the train-mode forward passes matter
                    _, _ = self.generator(im, mask)
                    self.discriminator(im, mask)

                    print_progessbar(
                        step,
                        n_batch,
                        Name=f'Stabilization Batch (iteration {iteration+1})',
                        Size=40,
                        erase=True)
Exemplo n.º 13
0
    def initialize_hypersphere_center(self, loader, net, eps=0.1):
        """
        Initialize the hypersphere center as the mean network output over one
        pass of the loader (all samples, no masking).
        ----------
        INPUT
            |---- loader (torch.utils.data.DataLoader) the loader of the data.
            |           Batches must unpack into five elements, the first
            |           being the input.
            |---- net (nn.Module) the DeepSAD network. Its output must be a
            |           vector embedding of the input, of size net.embed_dim.
            |---- eps (float) coordinates of the center whose absolute value is
            |           below eps (and non-zero) are pushed to +/- eps.
        OUTPUT
            |---- c (torch.Tensor) the initialized center.
        """
        count = 0
        center = torch.zeros(net.embed_dim, device=self.device)
        net.eval()
        with torch.no_grad():
            for step, batch in enumerate(loader):
                inputs, _, _, _, _ = batch
                outputs = net(inputs.to(self.device).float())
                count += outputs.shape[0]
                center += torch.sum(outputs, dim=0)

                if self.print_batch_progress:
                    print_progessbar(step,
                                     len(loader),
                                     Name='\t\tBatch',
                                     Size=20)

        center /= count
        # push near-zero (but non-zero) coordinates away from zero so that the
        # center is not trivially matched by a zero output
        tiny = torch.abs(center) < eps
        center[tiny & (center < 0)] = -eps
        center[tiny & (center > 0)] = eps
        return center
    def evaluate(self, dataset, n_neighbors=10, mode='test'):
        """
        Evaluate the representation classification capabilities by using as
        score the mean distance to the n_neighbors nearest centers.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is evaluated. It must return a tuple (input,
            |           label, mask, semi-supervised label, idx).
            |---- n_neighbors (int) the number of nearest centers to consider.
            |           It is capped at the number of centers available.
            |---- mode (str) define the set used. Either 'valid' or 'test'.
        OUTPUT
            |---- None (AUC and (idx, label, score) list stored in
            |           self.results[mode])
        """
        assert mode in [
            'valid', 'test'
        ], f'Invalid mode provided : {mode} was given. Expected either valid or test.'

        logger = logging.getLogger()
        # make dataloader
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             num_workers=self.n_job_dataloader)

        logger.info('Start Evaluating the Ablation Study.')
        idx_label_score = []
        # topk fails if k exceeds the number of centers
        k = min(n_neighbors, self.centers.shape[0])

        self.net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data on device
                input, label, mask, _, idx = data
                input = input.to(self.device).float()
                mask = mask.to(self.device)

                # mask input
                input = input * mask

                # first element returned is the transferred representation
                embed = self.net(input)[0]

                # mean distance to the k nearest centers is the anomaly score
                dist = torch.norm(self.centers.unsqueeze(0) -
                                  embed.unsqueeze(1),
                                  p=2,
                                  dim=2)
                min_dist = dist.topk(k, dim=1, largest=False)[0]
                score = torch.mean(min_dist, dim=1)

                # append idx, scores and label
                idx_label_score += list(
                    zip(idx.tolist(), label.tolist(), score.cpu().tolist()))

                if self.print_batch_progress:
                    print_progessbar(b,
                                     len(loader),
                                     Name='\t\t Evaluation Batch',
                                     Size=40,
                                     erase=True)

        # compute AUCs
        _, label, score = zip(*idx_label_score)
        label, score = np.array(label), np.array(score)
        auc = roc_auc_score(label, score)
        self.results[mode]['auc'] = auc
        self.results[mode]['score'] = idx_label_score

        logger.info(f'{mode.title()} AUC : {auc:.3%}')
        logger.info('Finished Evaluating the Ablation Study.')
    def pretrain(self, net, dataset):
        """
        Pretrain the autoencoder of the joint DMSAD network with a masked
        reconstruction loss computed on the normal samples only
        (semi-supervised label different from -1).
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD to pretrain. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is pretrained. It must return a tuple (image,
            |           label, mask, semi-supervized labels, idx).
        OUTPUT
            |---- net (nn.Module) The pretrained joint DMSAD. Per-epoch
            |           [epoch, mean loss] rows are stored in
            |           self.pretrain_loss and the duration in self.pretrain_time.
        """
        logger = logging.getLogger()

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        net = net.to(self.device)
        # flag presumably read by the network to select its outputs ("provide
        # only the reconstruction") -- TODO confirm: the loop below still
        # unpacks two outputs from net(...)
        net.return_svdd_embed = False

        reconstruction_loss = MaskedMSELoss()
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        logger.info(' Start Pretraining the Autoencoder.')
        start_time = time.time()
        history = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch_pretrain):
            net.train()
            running_loss = 0.0
            epoch_start = time.time()

            for step, batch in enumerate(train_loader):
                inputs, _, masks, semi_labels, _ = batch
                inputs = inputs.to(self.device).float().requires_grad_(True)
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)

                # mask the inputs and keep only the normal samples
                is_normal = semi_labels != -1
                inputs = (inputs * masks)[is_normal]
                masks = masks[is_normal]

                # forward + backward + optimizer step
                optimizer.zero_grad()
                reconstruction, _ = net(inputs)
                loss = reconstruction_loss(reconstruction, inputs, masks)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(step,
                                     n_batch,
                                     Name='\t\tBatch',
                                     Size=40,
                                     erase=True)

            mean_loss = running_loss / n_batch
            logger.info(
                f'----| Epoch {epoch + 1:03}/{self.n_epoch_pretrain:03} '
                f'| Pretrain Time {time.time() - epoch_start:.3f} [s] '
                f'| Pretrain Loss {mean_loss:.6f} |')
            history.append([epoch + 1, mean_loss])

        self.pretrain_loss = history
        self.pretrain_time = time.time() - start_time
        logger.info(
            f'---- Finished Pretraining the AutoEncoder in {self.pretrain_time:.3f} [s].'
        )

        return net
    def train(self, net, dataset, valid_dataset=None):
        """
        Train jointly the DMSAD and the autoencoder on the provided dataset.

        The total loss is the weighted sum of a masked reconstruction loss and
        of the DMSAD loss on the embeddings. The reconstruction loss is zero for
        samples labelled as known anomalies (semi-supervised label -1). At the
        end of each epoch, the centers that attract fewer than a gamma fraction
        of the non-anomalous samples of the most populated center are removed.
        The radii are set once training is over.
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return a tuple (image,
            |           label, mask, semi-supervized labels, idx).
            |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
            |           to validate the model at each epoch. No validation is
            |           performed if not provided.
        OUTPUT
            |---- net (nn.Module) The trained joint DMSAD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        # put net on device
        net = net.to(self.device)
        # set the network to provide both the reconstruction and the embedding
        net.return_svdd_embed = True

        # Initialize the hyper-sphere centers by Kmeans
        if self.c is None:
            logger.info(' Initializing the hypersheres centers.')
            self.initialize_centers(train_loader, net)
            logger.info(
                f' {self.c.shape[0]} centers successfully initialized.')

        # define the reconstruction and DMSAD loss functions
        loss_fn_rec = MaskedMSELoss()
        loss_fn_ad = DMSADLoss(self.eta, eps=self.eps)

        # Compute the scaling factors for the reconstruction and DMSAD losses
        logger.info(' Initializing the loss scale factors.')
        self.initialize_loss_scale_weight(train_loader, net, loss_fn_rec,
                                          loss_fn_ad)
        logger.info(
            f' reconstruction loss scale factor initialized to {self.scale_rec:.6f}'
        )
        logger.info(
            f' MSAD embdeding loss scale factor initialized to {self.scale_ad:.6f}'
        )

        # define the optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define the learning rate scheduler : 90% reduction at each milestone
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info(' Start Training Jointly the DMSAD and the Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch):
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()
            # number of non-anomalous samples assigned to each center this epoch
            n_k = torch.zeros(self.c.shape[0], device=self.device)

            for b, data in enumerate(train_loader):
                # get batch data
                inputs, _, mask, semi_label, _ = data
                inputs = inputs.to(self.device).float().requires_grad_(True)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                # samples that are not labelled as known anomalies
                is_normal = semi_label != -1

                # mask input
                inputs = inputs * mask

                # Update the network by backpropagation using the two losses.
                optimizer.zero_grad()
                rec, embed = net(inputs)
                # Reconstruction loss only on non-anomalous samples: for known
                # anomalies (label -1) the reconstruction is replaced by the
                # input itself, which gives a zero reconstruction error.
                rec = torch.where(
                    is_normal.view(-1, 1, 1, 1).expand(*inputs.shape), rec,
                    inputs)
                loss_rec = self.scale_rec * self.criterion_weight[
                    0] * loss_fn_rec(rec, inputs, mask)
                # DMSAD loss
                loss_ad = self.scale_ad * self.criterion_weight[
                    1] * loss_fn_ad(embed, self.c, semi_label)
                # total loss
                loss = loss_rec + loss_ad
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # Assign each sample to its closest center and count the
                # non-anomalous ones per center. This is bookkeeping only, so
                # no autograd graph is built for it.
                with torch.no_grad():
                    dist = torch.norm(self.c.unsqueeze(0) -
                                      embed.detach().unsqueeze(1),
                                      p=2,
                                      dim=2)
                    closest = torch.argmin(dist, dim=1)
                    n_k += torch.bincount(closest[is_normal],
                                          minlength=self.c.shape[0]).to(
                                              n_k.dtype)

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            # remove centers with less than gamma fraction of largest hypersphere number of sample
            self.c = self.c[n_k >= self.gamma * torch.max(n_k)]

            # intermediate validation of the model if required
            valid_auc = ''
            if valid_dataset:
                auc_rec, auc_ad = self.evaluate(net,
                                                valid_dataset,
                                                mode='valid',
                                                final=False)
                valid_auc = f' Rec AUC {auc_rec:.3%} | MSAD AUC {auc_ad:.3%} |'

            # print epoch statistics
            logger.info(
                f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train Time {time.time() - epoch_start_time:.3f} [s] '
                f'| Train Loss {epoch_loss / n_batch:.6f} '
                f'| N sphere {self.c.shape[0]:03} |' + valid_auc)
            # store loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update learning rate if milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # Read the lr from the optimizer: scheduler.get_lr() called
                # outside of step() re-applies gamma at a milestone and would
                # report a value 10x too small.
                new_lr = optimizer.param_groups[0]['lr']
                logger.info(
                    f'---- LR Scheduler : new learning rate {new_lr:g}')

            # re-initialized loss scale factors after few epochs when the centers are more or less defined
            if epoch + 1 == self.reset_scaling_epoch:
                with torch.no_grad():
                    # Compute the scaling factors for the reconstruction and DMSAD losses
                    logger.info('---- Reinitializing the loss scale factors.')
                    self.initialize_loss_scale_weight(train_loader, net,
                                                      loss_fn_rec, loss_fn_ad)
                    logger.info(
                        f'---- reconstruction loss scale factor reinitialized to {self.scale_rec:.6f}'
                    )
                    logger.info(
                        f'---- MSAD embdeding loss scale factor reinitialized to {self.scale_ad:.6f}'
                    )

        # Set the radius of each sphere as 1-gamma quantile of normal samples distances
        logger.info(
            f'---- Setting the hyperspheres radii as the {1-self.gamma:.1%} quantiles of normal sample distances.'
        )
        self.set_radius(train_loader, net)
        logger.info(f'---- {self.R.shape[0]} radii successufully defined.')

        # End Training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'---- Finished jointly training the DMSAD and the Autoencoder in {self.train_time:.3f} [s].'
        )

        return net
Exemplo n.º 17
0
    def train(self, dataset, net, valid_dataset=None, checkpoint_path=None):
        """
        Train the DMSAD network on the provided dataset.

        If checkpoint_path points to an existing checkpoint, training resumes
        from it (network, optimizer, centers and learning-rate schedule), and a
        checkpoint is written there every 10 epochs. When checkpoint_path is
        None, no checkpoint is read or written.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, label, mask
            |           semi-supervized labels and the index.
            |---- net (nn.Module) The DMSAD to train.
            |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
            |           to validate the network at each epoch. Not validated if
            |           not provided.
            |---- checkpoint_path (str) path of the checkpoint file to resume
            |           from and to save to. Optional.
        OUTPUT
            |---- net (nn.Module) The trained DMSAD.
        """
        logger = logging.getLogger()

        # make the train dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_job_dataloader)
        n_batch = len(train_loader)

        # put net to device
        net = net.to(self.device)

        # Read the checkpoint (if any) before anything else so that the
        # centers are not recomputed when they are restored from it.
        # torch.load(None) does not raise FileNotFoundError, hence the guard.
        checkpoint = None
        if checkpoint_path is not None:
            try:
                checkpoint = torch.load(checkpoint_path,
                                        map_location=self.device)
            except FileNotFoundError:
                checkpoint = None

        # initialize hypersphere center (unless restored from a checkpoint)
        if self.c is None and checkpoint is None:
            logger.info('Initializing the hypersphere centers.')
            self.initialize_centers(train_loader, net)
            logger.info(f'{self.c.shape[0]} centers successfully initialized.')

        # define loss criterion
        loss_fn = DMSADLoss(self.eta, eps=self.eps)

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Restore the training state from the checkpoint if any
        if checkpoint is not None:
            n_epoch_finished = checkpoint['n_epoch_finished']
            net.load_state_dict(checkpoint['net_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.c = checkpoint['centers']
            logger.info(
                f'Checkpoint loaded with {n_epoch_finished} epochs finished.')
        else:
            logger.info('Training from scratch.')
            n_epoch_finished = 0

        # define scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)
        # When resuming, continue the schedule where it stopped: the restored
        # optimizer lr already contains past decays, and the scheduler must not
        # count epochs from 0 again (milestones would fire at the wrong epoch).
        if n_epoch_finished > 0:
            scheduler.last_epoch = n_epoch_finished

        # Start training
        logger.info('Start Training the DMSAD.')
        start_time = time.time()
        epoch_loss_list = []

        for epoch in range(n_epoch_finished, self.n_epoch):
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()
            # number of non-anomalous samples assigned to each center this epoch
            n_k = torch.zeros(self.c.shape[0], device=self.device)

            for b, data in enumerate(train_loader):
                # get input and semi-supervized labels
                inputs, _, mask, semi_label, _ = data
                # put them to device
                inputs = inputs.to(self.device).float().requires_grad_(True)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # mask input
                inputs = inputs * mask

                # zero the network's gradients
                optimizer.zero_grad()
                # optimize by backpropagation
                _, embed = net(inputs)
                loss = loss_fn(embed, self.c, semi_label)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # Assign each sample to its closest center and count the
                # non-anomalous ones (label != -1) per center. Bookkeeping only,
                # so no autograd graph is built.
                with torch.no_grad():
                    dist = torch.norm(self.c.unsqueeze(0) -
                                      embed.detach().unsqueeze(1),
                                      p=2,
                                      dim=2)
                    closest = torch.argmin(dist, dim=1)
                    n_k += torch.bincount(closest[semi_label != -1],
                                          minlength=self.c.shape[0]).to(
                                              n_k.dtype)

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            # remove centers with less than gamma fraction of largest hypersphere number of sample
            self.c = self.c[n_k >= self.gamma * torch.max(n_k)]

            # validate if required
            valid_auc = ''
            if valid_dataset:
                auc = self.evaluate(net,
                                    valid_dataset,
                                    return_auc=True,
                                    print_to_logger=False,
                                    save_tSNE=False)
                valid_auc = f' Valid AUC {auc:.3%} |'

            # log the epoch statistics
            logger.info(
                f'----| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train Time: {time.time() - epoch_start_time:.3f} [s] '
                f'| Train Loss: {epoch_loss / n_batch:.6f} '
                f'| N sphere {self.c.shape[0]:03} |' + valid_auc)

            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update scheduler
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # Read the lr from the optimizer: scheduler.get_lr() called
                # outside of step() re-applies gamma at a milestone.
                new_lr = optimizer.param_groups[0]['lr']
                logger.info(
                    f'---- LR Scheduler : new learning rate {new_lr:g}')

            # Save checkpoint every 10 epochs (only if a path was given)
            if checkpoint_path is not None and (epoch + 1) % 10 == 0:
                checkpoint = {
                    'n_epoch_finished': epoch + 1,
                    'net_state': net.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'centers': self.c
                }
                torch.save(checkpoint, checkpoint_path)
                logger.info('---- Checkpoint Saved.')

        # Set the radius of each sphere as 1-gamma quantile of normal samples distances
        logger.info(
            f'---- Setting the hyperspheres radii as the {1-self.gamma:.1%} quantiles of normal sample distances.'
        )
        self.set_radius(train_loader, net)
        logger.info(f'---- {self.R.shape[0]} radii successufully defined.')

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'---- Finished Training DMSAD in {self.train_time:.3f} [s]')

        return net
Exemplo n.º 18
0
    def evaluate(self,
                 net,
                 dataset,
                 return_auc=False,
                 print_to_logger=True,
                 save_tSNE=True):
        """
        Evaluate the DMSAD network on the provided dataset.

        The anomaly score of a sample is its Euclidean distance to the closest
        hypersphere center in `self.c`. The results are stored in
        `self.eval_time`, `self.eval_scores` (tuples of index, label, score,
        closest-center index and embedding) and `self.eval_auc`.
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD network to evaluate. Its forward
            |           pass must return a pair whose second element is the
            |           embedding.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is evaluated. It must return a tuple (image,
            |           label, mask, semi-supervized labels, idx).
            |---- return_auc (bool) whether to return the computed auc or not.
            |---- print_to_logger (bool) whether to print in the logger.
            |---- save_tSNE (bool) whether to replace the stored embeddings by
            |           their 2D t-SNE representation.
        OUTPUT
            |---- auc (float) the ROC AUC of the scores, returned only if
            |           return_auc is True (None otherwise).
        """
        logger = logging.getLogger() if print_to_logger else None

        def _to_list(tensor):
            # bring a tensor back to the CPU as (nested) Python lists
            return tensor.cpu().data.numpy().tolist()

        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             num_workers=self.n_job_dataloader)
        net = net.to(self.device)

        if logger is not None:
            logger.info('Start Evaluating the DMSAD.')
        start_time = time.time()
        records = []

        net.eval()
        with torch.no_grad():
            for batch_idx, (inputs, labels, masks, semi_labels,
                            indices) in enumerate(loader):
                inputs = inputs.to(self.device).float()
                labels = labels.to(self.device)
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                indices = indices.to(self.device)

                # embed the masked input
                _, embeddings = net(inputs * masks)
                # distance of every sample to every center, then the closest one
                center_dist = torch.norm(self.c.unsqueeze(0) -
                                         embeddings.unsqueeze(1),
                                         p=2,
                                         dim=2)
                min_dist, nearest = center_dist.min(dim=1)

                # the anomaly score is the distance to the closest center
                records.extend(
                    zip(_to_list(indices), _to_list(labels),
                        _to_list(min_dist), _to_list(nearest),
                        _to_list(embeddings)))

                if self.print_batch_progress:
                    print_progessbar(batch_idx,
                                     len(loader),
                                     Name='\t\t Evaluation Batch',
                                     Size=40,
                                     erase=True)

        index, label, score, sphere_index, embed = zip(*records)
        label = np.array(label)
        score = np.array(score)
        auc = roc_auc_score(label, score)

        if save_tSNE:
            embed_2d = TSNE(n_components=2).fit_transform(np.array(embed))
            records = list(
                zip(index, label.tolist(), score.tolist(), sphere_index,
                    embed_2d.tolist()))

        self.eval_time = time.time() - start_time
        self.eval_scores = records
        self.eval_auc = auc

        if logger is not None:
            logger.info(f'Evaluation Time : {self.eval_time}')
            logger.info(f'Evaluation AUC : {self.eval_auc:.3%}')
            logger.info('Finished Evaluating the DMSAD.')

        if return_auc:
            return auc
    def evaluate(self, net, dataset, mode='test', final=False):
        """
        Evaluate the jointly trained DMSAD and autoencoder on the given dataset.

        Two anomaly scores are computed for each sample: the mean masked
        reconstruction error, and the distance to the closest hypersphere
        center (minus the radius of that sphere once the radii are set).
        ----------
        INPUT
            |---- net (nn.Module) The DMSAD to validate. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is validated. It must return a tuple (image,
            |           label, mask, semi-supervized labels, idx).
            |---- mode (str) either 'valid' or 'test'. Define the evaluation mode.
            |           In 'valid' mode the evaluation can return the
            |           reconstruction and MSAD AUCs and compute the best
            |           thresholds maximizing the F1-scores. In 'test' mode the
            |           validation thresholds (when set) are used to compute the
            |           F1-scores.
            |---- final (bool) whether the call represents the final validation,
            |           in which case the validation results are saved. Only
            |           relevant if mode is 'valid'.
        OUTPUT
            |---- auc (tuple (reconstruction auc, ad auc)) returned only when
            |           mode is 'valid' and final is False. Otherwise None is
            |           returned and the results are stored on the trainer.
        """
        assert mode in [
            'valid', 'test'
        ], f'Mode {mode} is not supported. Should be either "valid" or "test".'
        logger = logging.getLogger()

        # make the dataloader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        # put net on device
        net = net.to(self.device)
        # set the network to provide both the reconstruction and the embedding
        net.return_svdd_embed = True

        # define the two loss functions; no reduction for the reconstruction
        # loss so that an anomaly score can be computed for each sample
        loss_fn_rec = MaskedMSELoss(reduction='none')
        loss_fn_ad = DMSADLoss(self.eta, self.eps)

        # Validate
        if final or mode == 'test':
            logger.info(
                f' Start Evaluating the jointly trained DMSAD and AutoEncoder in {mode} mode.'
            )
        epoch_loss = 0.0
        n_batch = len(loader)
        start_time = time.time()
        # placeholders for (idx, label, score, ...) tuples
        idx_label_score_rec, idx_label_score_ad = [], []

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get data on device
                inputs, label, mask, semi_label, idx = data
                inputs = inputs.to(self.device).float()
                label = label.to(self.device)
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                idx = idx.to(self.device)

                # mask input
                inputs = inputs * mask

                # compute the losses
                rec, embed = net(inputs)
                loss_rec = loss_fn_rec(rec, inputs, mask)
                loss_ad = loss_fn_ad(embed, self.c, semi_label)
                # reconstruction anomaly score: mean loss over all non-batch dims
                rec_score = torch.mean(loss_rec,
                                       dim=tuple(range(1, rec.dim())))

                # find closest sphere
                dist, sphere_idx = torch.min(torch.norm(self.c.unsqueeze(0) -
                                                        embed.unsqueeze(1),
                                                        p=2,
                                                        dim=2),
                                             dim=1)

                if self.R is not None:
                    # anomaly scores positive if dist > R and negative if dist < R;
                    # R is gathered for each sample's closest sphere in one op
                    ad_score = dist - self.R[sphere_idx]
                else:
                    # else scores is just the minimal distance to a center
                    ad_score = dist

                # append scores to the placeholder lists
                idx_list = idx.cpu().data.numpy().tolist()
                label_list = label.cpu().data.numpy().tolist()
                idx_label_score_rec += list(
                    zip(idx_list, label_list,
                        rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(
                    zip(idx_list, label_list,
                        ad_score.cpu().data.numpy().tolist(),
                        sphere_idx.cpu().data.numpy().tolist(),
                        embed.cpu().data.numpy().tolist()))

                # compute the overall loss
                loss = self.scale_rec * self.criterion_weight[0] * (
                    torch.sum(loss_rec) / torch.sum(mask))
                loss += self.scale_ad * self.criterion_weight[1] * loss_ad
                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\t Evaluation Batch',
                                     Size=40,
                                     erase=True)

        # compute AUCs
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        auc_rec = roc_auc_score(label, rec_score)

        _, label, ad_score, _, _ = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        auc_ad = roc_auc_score(label, ad_score)

        if mode == 'valid':
            if final:
                # save results
                self.valid_time = time.time() - start_time
                self.valid_scores_rec = idx_label_score_rec
                self.valid_auc_rec = auc_rec
                self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(
                    rec_score, label, metric=f1_score)
                self.valid_scores_ad = idx_label_score_ad
                self.valid_auc_ad = auc_ad
                self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(
                    ad_score, label, metric=f1_score)
                # print infos
                logger.info(f'---- Validation Time: {self.valid_time:.3f} [s]')
                logger.info(
                    f'---- Validation Loss: {epoch_loss / n_batch:.6f}')
                logger.info(
                    f'---- Validation reconstruction AUC: {self.valid_auc_rec:.3%}'
                )
                logger.info(
                    f'---- Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}'
                )
                logger.info(
                    f'---- Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}'
                )
                logger.info(
                    f'---- Validation MSAD AUC: {self.valid_auc_ad:.3%}')
                logger.info(
                    f'---- Best Threshold for the MSAD score maximizing F1-score: {self.scores_threhold_ad:.3f}'
                )
                logger.info(
                    f'---- Best F1-score on MSAD score: {self.valid_f1_ad:.3%}'
                )
                logger.info(
                    '---- Finished validating the Joint DMSAD and AutoEncoder.\n'
                )
            else:
                return auc_rec, auc_ad

        elif mode == 'test':
            # save results
            self.test_time = time.time() - start_time
            self.test_scores_rec = idx_label_score_rec
            self.test_auc_rec = auc_rec
            self.test_scores_ad = idx_label_score_ad
            self.test_auc_ad = auc_ad

            # print infos
            logger.info(f'---- Test Time: {self.test_time:.3f} [s]')
            logger.info(f'---- Test Loss: {epoch_loss / n_batch:.6f}')
            logger.info(
                f'---- Test reconstruction AUC: {self.test_auc_rec:.3%}')
            if self.scores_threhold_rec is not None:
                self.test_f1_rec = f1_score(
                    label, np.where(rec_score > self.scores_threhold_rec, 1,
                                    0))
                logger.info(
                    f'---- Best F1-score on reconstruction score: {self.test_f1_rec:.3%}'
                )
            logger.info(f'---- Test MSAD AUC: {self.test_auc_ad:.3%}')
            if self.scores_threhold_ad is not None:
                self.test_f1_ad = f1_score(
                    label, np.where(ad_score > self.scores_threhold_ad, 1, 0))
                logger.info(
                    f'---- Best F1-score on MSAD score: {self.test_f1_ad:.3%}')
            logger.info(
                '---- Finished testing the Joint DMSAD and AutoEncoder.\n')
Exemplo n.º 20
0
    def evaluate(self,
                 dataset,
                 net,
                 save_tSNE=False,
                 return_loss=True,
                 print_to_logger=True):
        """
        Evaluate the Contrastive network on the provided dataset.

        Each item provides two views of the same sample; the InfoNCE loss is
        computed between their normalized projections and averaged over the
        batches. Incomplete batches are dropped so the loss can be computed.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is evaluated. It must return a tuple
            |           (view 1, view 2, label, idx).
            |---- net (nn.Module) The Encoder network to validate. Its forward
            |           pass returns a pair (h, z).
            |---- save_tSNE (bool) whether to save a 2D t-SNE representation of
            |           the embeded data points (stored in self.eval_repr).
            |---- return_loss (bool) whether to return the validation loss.
            |---- print_to_logger (bool) whether to print in the logger.
        OUTPUT
            |---- (loss) (float) the mean validation loss over the batches,
            |           returned only if return_loss is True. NaN if the
            |           dataset does not fill a single batch.
        """
        logger = logging.getLogger() if print_to_logger else None
        # make dataloader (with drop_last = True to ensure that the loss can be computed)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             num_workers=self.n_job_dataloader,
                                             drop_last=True)

        # put net on device
        net = net.to(self.device)

        # define loss function
        loss_fn = InfoNCE_loss(self.tau, self.batch_size, device=self.device)

        if logger is not None:
            logger.info("Start Evaluating Contrastive.")

        sum_loss = 0.0
        idx_h_z = []
        n_batch = len(loader)

        net.eval()
        with torch.no_grad():
            for b, data in enumerate(loader):
                # get input
                input_1, input_2, _, idx = data
                input_1 = input_1.to(self.device).float()
                input_2 = input_2.to(self.device).float()
                # forward (only the projection of the second view is needed)
                h_1, z_1 = net(input_1)
                _, z_2 = net(input_2)
                # normalize
                z_1 = F.normalize(z_1, dim=1)
                z_2 = F.normalize(z_2, dim=1)
                # accumulate the loss of every batch
                sum_loss += loss_fn(z_1, z_2).item()
                # save embeddings
                if save_tSNE:
                    idx_h_z += list(
                        zip(idx.cpu().data.numpy().tolist(),
                            h_1.cpu().data.numpy().tolist(),
                            z_1.cpu().data.numpy().tolist()))

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tEvaluation Batch',
                                     Size=40,
                                     erase=True)

        if save_tSNE:
            if logger is not None:
                logger.info("Computing the t-SNE representation.")
            # Apply t-SNE transform on embeddings
            index, h, z = zip(*idx_h_z)
            h, z = np.array(h), np.array(z)
            h = TSNE(n_components=2).fit_transform(h)
            z = TSNE(n_components=2).fit_transform(z)
            self.eval_repr = list(zip(index, h.tolist(), z.tolist()))

            if logger is not None:
                logger.info("Succesfully computed the t-SNE representation ")

        if return_loss:
            # mean over batches (the previous code divided only the LAST
            # batch's loss by the number of batches)
            return sum_loss / n_batch if n_batch > 0 else float('nan')
Exemplo n.º 21
0
    def test(self, dataset, net):
        """
        Evaluate the DROCC network on a test set and store its performance.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the test dataset. Each item
            |           must be a 5-tuple (image, label, mask, semi-label, index);
            |           the semi-label is ignored here.
            |---- net (nn.Module) The DROCC to test. Its forward pass must return
            |           one logit per sample (with a singleton dimension 1).
        OUTPUT
            |---- None. Sets self.test_time, self.test_scores (list of
            |           (index, label, score)), self.test_auc and self.test_f1.
        """
        logger = logging.getLogger()

        # test loader over the (image, label, mask, semi-label, index) samples
        test_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        net = net.to(self.device)
        bce = nn.BCEWithLogitsLoss()

        logger.info('>>> Start Testing of the DROCC.')
        total_loss, batch_count = 0.0, 0
        total_batches = len(test_loader)
        t0 = time.time()
        records = []

        net.eval()
        with torch.no_grad():
            for batch_i, (images, targets, masks, _, indices) in enumerate(test_loader):
                images = images.to(self.device).float()
                targets = targets.to(self.device).float()
                indices = indices.to(self.device)
                masks = masks.to(self.device)

                # only the masked region of the image is fed to the network
                logits = net(images * masks).squeeze(dim=1)
                batch_loss = bce(logits, targets)
                # sigmoid of the logit: high for abnormal (target = 1), low for normal (target = 0)
                scores = torch.sigmoid(logits)

                records.extend(zip(indices.cpu().data.numpy().tolist(),
                                   targets.cpu().data.numpy().tolist(),
                                   scores.cpu().data.numpy().tolist()))

                total_loss += batch_loss.item()
                batch_count += 1

                if self.print_batch_progress:
                    print_progessbar(batch_i, total_batches, Name='\t\tBatch', Size=20)

        self.test_time = time.time() - t0
        self.test_scores = records

        _, all_labels, all_scores = zip(*records)
        all_labels = np.array(all_labels)
        all_scores = np.array(all_scores)
        self.test_auc = roc_auc_score(all_labels, all_scores)
        predictions = np.where(all_scores > self.scores_threshold, 1, 0)
        self.test_f1 = f1_score(all_labels, predictions)

        # report the test statistics
        for message in (f'>>> Testing Time: {self.test_time:.3f} [s]',
                        f'>>> Test Loss: {total_loss / batch_count:.6f}',
                        f'>>> Test AUC: {self.test_auc:.3%}',
                        f'>>> Test F1-score: {self.test_f1:.3%}',
                        '>>> Finished testing the DROCC.\n'):
            logger.info(message)
    def test(self, dataset, net):
        """
        Evaluate the joint DeepSVDD / AutoEncoder network on a test set.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the test dataset. Each item
            |           must be a 5-tuple (image, label, mask, semi-label, index).
            |---- net (nn.Module) The DeepSVDD to test. The network should be an
            |           autoencoder whose forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None. Sets self.test_time, and for both the reconstruction
            |           score (``_rec``) and the DeepSVDD score (``_ad``) the
            |           per-sample scores, the AUC and the F1-score.
        """
        logger = logging.getLogger()

        test_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)

        net = net.to(self.device)
        net.return_svdd_embed = True

        # per-element reconstruction loss and the DeepSVDD loss
        rec_criterion = MaskedMSELoss(reduction='none')
        svdd_criterion = self.SVDDLoss(self.space_repr,
                                       self.nu,
                                       eps=self.eps,
                                       soft_boundary=self.soft_boundary)

        def svdd_distance(embedding):
            """Squared distance of each embedding to the center or subspace (large = anomalous)."""
            if self.use_subspace:
                projected = torch.matmul(self.space_repr,
                                         embedding.transpose(0, 1)).transpose(0, 1)
                return torch.sum((embedding - projected)**2, dim=1)
            return torch.sum((embedding - self.space_repr)**2, dim=1)

        logger.info('>>> Start Testing the joint DeepSVDD and AutoEncoder.')
        total_loss, batch_count = 0.0, 0
        total_batches = len(test_loader)
        t0 = time.time()
        rec_records, ad_records = [], []

        net.eval()
        with torch.no_grad():
            for batch_i, (images, labels, masks, semi_labels, indices) in enumerate(test_loader):
                images = images.to(self.device).float()
                labels = labels.to(self.device)
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                indices = indices.to(self.device)

                # keep only the masked region of the image
                images = images * masks

                rec, embed = net(images)
                rec_loss_map = rec_criterion(rec, images, masks)
                ad_loss = svdd_criterion(embed, self.R)

                # reconstruction score: mean over every non-batch dimension
                rec_scores = torch.mean(rec_loss_map, dim=tuple(range(1, rec.dim())))
                dist = svdd_distance(embed)
                ad_scores = dist - self.R**2 if self.soft_boundary else dist

                # overall loss = weighted masked-mean reconstruction + weighted SVDD loss
                loss = self.scale_rec * self.criterion_weight[0] * (
                    torch.sum(rec_loss_map) / torch.sum(masks))
                loss += self.scale_em * self.criterion_weight[1] * ad_loss

                index_list = indices.cpu().data.numpy().tolist()
                label_list = labels.cpu().data.numpy().tolist()
                rec_records.extend(
                    zip(index_list, label_list, rec_scores.cpu().data.numpy().tolist()))
                ad_records.extend(
                    zip(index_list, label_list, ad_scores.cpu().data.numpy().tolist()))

                total_loss += loss.item()
                batch_count += 1

                if self.print_batch_progress:
                    print_progessbar(batch_i, total_batches, Name='\t\tBatch', Size=20)

        self.test_time = time.time() - t0

        # metrics on the reconstruction score
        self.test_scores_rec = rec_records
        _, rec_labels, rec_values = zip(*rec_records)
        rec_labels, rec_values = np.array(rec_labels), np.array(rec_values)
        self.test_auc_rec = roc_auc_score(rec_labels, rec_values)
        self.test_f1_rec = f1_score(
            rec_labels, np.where(rec_values > self.scores_threhold_rec, 1, 0))

        # metrics on the DeepSVDD score
        self.test_scores_ad = ad_records
        _, ad_labels, ad_values = zip(*ad_records)
        ad_labels, ad_values = np.array(ad_labels), np.array(ad_values)
        self.test_auc_ad = roc_auc_score(ad_labels, ad_values)
        self.test_f1_ad = f1_score(
            ad_labels, np.where(ad_values > self.scores_threhold_ad, 1, 0))

        logger.info(f'>>> Test Time: {self.test_time:.3f} [s]')
        logger.info(f'>>> Test Loss: {total_loss / batch_count:.6f}')
        logger.info(f'>>> Test reconstruction AUC: {self.test_auc_rec:.3%}')
        logger.info(
            f'>>> Test F1-score on reconstruction score: {self.test_f1_rec:.3%}'
        )
        logger.info(f'>>> Test AD AUC: {self.test_auc_ad:.3%}')
        logger.info(
            f'>>> Test F1-score on DeepSVDD score: {self.test_f1_ad:.3%}')
        logger.info(
            '>>> Finished Testing the Joint DeepSVDD and AutoEncoder.\n')
Exemplo n.º 23
0
    def train(self, dataset, net, valid_dataset=None):
        """
        Train the network on the provided dataset with the InfoNCE contrastive
        loss between the projections of two transformed views of each image.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. Each item must be a 4-tuple whose
            |           first two elements are two transformed versions of an
            |           image (the last two elements are ignored here).
            |---- net (nn.Module) The Encoder to train. Its forward pass must
            |           return a pair (h, z); only z enters the loss.
            |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
            |           to validate the network at each epoch. Not validated if
            |           not provided.
        OUTPUT
            |---- net (nn.Module) The trained Encoder.
        """
        logger = logging.getLogger()

        # make dataloader (with drop_last = True to ensure that the loss can be computed)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_job_dataloader,
            drop_last=True)

        # put net on device
        net = net.to(self.device)

        # define loss function
        loss_fn = InfoNCE_loss(self.tau, self.batch_size, device=self.device)

        # define the optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define the learning rate scheduler : 90% reduction at each milestone
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Start Contrastive Training.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch):
            # (re)enter train mode every epoch: the validation through
            # self.evaluate may leave the network in eval mode
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get data on device (gradients are only needed w.r.t. the
                # network parameters, not the inputs)
                input_1, input_2, _, _ = data
                input_1 = input_1.to(self.device).float()
                input_2 = input_2.to(self.device).float()

                # Update by Backpropagation : Forward + Backward + step
                optimizer.zero_grad()
                _, z_1 = net(input_1)
                _, z_2 = net(input_2)

                # L2-normalize the projections before the contrastive loss
                z_1 = F.normalize(z_1, dim=1)
                z_2 = F.normalize(z_2, dim=1)

                loss = loss_fn(z_1, z_2)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\tTrain Batch',
                                     Size=40,
                                     erase=True)

            # compute valid_loss if required
            valid_loss = ''
            if valid_dataset:
                valid_value = self.evaluate(valid_dataset,
                                            net,
                                            save_tSNE=False,
                                            return_loss=True,
                                            print_to_logger=False)
                valid_loss = f' Valid Loss {valid_value:.6f} |'

            # display epoch statistics
            logger.info(f'----| Epoch {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Time {time.time() - epoch_start_time:.3f} [s]'
                        f'| Loss {epoch_loss / n_batch:.6f} |' + valid_loss)
            # store loss
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update learning rate if milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestones:
                # get_last_lr() reports the rate actually in use; get_lr() called
                # outside of step() re-applies gamma at a milestone
                logger.info(
                    f'---- LR Scheduler : new learning rate {scheduler.get_last_lr()[0]:g}'
                )

        # Save results
        self.train_time = time.time() - start_time
        self.train_loss = epoch_loss_list
        logger.info(
            f'---- Finished Contrastive Training in {self.train_time:.3f} [s].'
        )

        return net
    def train(self, dataset, net, valid_dataset=None):
        """
        Train the joint DeepSVDD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. Each item must be a 5-tuple whose
            |           1st, 3rd and 4th elements are an image, a mask and a
            |           semi-supervized label (-1 marks known abnormal samples).
            |---- net (nn.Module) The DeepSVDD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
            |---- valid_dataset (torch.utils.data.Dataset) optional dataset passed
            |           to self.validate at the end of each epoch to log the
            |           reconstruction and AD AUCs. Not validated if not provided.
        OUTPUT
            |---- net (nn.Module) The trained joint DeepSVDD.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True  # enable the network to provide the SVDD embedding

        # initialize hypersphere center or subspace projection matrix
        if self.space_repr is None:
            if self.use_subspace:
                logger.info('>>> Initializing the subspace projection matrix.')
                self.space_repr = self.initialize_projection_matrix(
                    train_loader, net)
                logger.info('>>> Projection matrix succesfully initialized.')
            else:
                logger.info('>>> Initializing the hypersphere center.')
                self.space_repr = self.initialize_hypersphere_center(
                    train_loader, net)
                logger.info('>>> Center succesfully initialized.')

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss()
        criterion_ad = self.SVDDLoss(self.space_repr,
                                     self.nu,
                                     eps=self.eps,
                                     soft_boundary=self.soft_boundary)

        # compute the scale weight so that the rec and svdd losses are scaled and comparable
        self.initialize_loss_scale_weight(train_loader, net, criterion_rec,
                                          criterion_ad)

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define scheduler : 90% reduction at each milestone
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the Joint DeepSVDD and Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()
            # distances collected during the epoch to update the radius R
            dist = []
            # R is only updated with a soft boundary, once the warm-up is over
            update_radius = self.soft_boundary and (epoch + 1 > self.n_epoch_warm_up)

            for b, data in enumerate(train_loader):
                input, _, mask, semi_label, _ = data
                # put inputs to device (gradients are only needed w.r.t. the
                # network parameters, not the inputs)
                input = input.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # mask the input (keep only the object)
                input = input * mask

                # zeros the gradient
                optimizer.zero_grad()

                # Update network parameters by backpropagation on the two criterion
                rec, embed = net(input)
                # reconstruction loss
                # ignore reconstruction for known abnormal samples (semi-label -1):
                # their "reconstruction" is the input itself, so their loss is 0
                rec = torch.where(
                    semi_label.view(-1, 1, 1, 1).expand(*input.shape) != -1,
                    rec, input)
                loss_rec = criterion_rec(rec, input, mask)
                loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
                # SVDD embedding loss
                loss_ad = criterion_ad(embed, self.R)
                loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad
                loss = loss_rec + loss_ad

                loss.backward()
                optimizer.step()

                # compute squared distances to update radius R (outside the graph)
                if update_radius:
                    embed_d = embed.detach()
                    if self.use_subspace:
                        projected = torch.matmul(
                            self.space_repr, embed_d.transpose(0, 1)).transpose(0, 1)
                        dist.append(torch.sum((embed_d - projected)**2, dim=1))
                    else:
                        dist.append(torch.sum((self.space_repr - embed_d)**2, dim=1))

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

            # update radius
            if update_radius:
                self.R.data = torch.tensor(self.get_radius(
                    torch.cat(dist, dim=0)),
                                           device=self.device)

            valid_auc = ''
            if valid_dataset:
                auc_rec, auc_ad = self.validate(valid_dataset,
                                                net,
                                                final=False)
                # validation may leave the network in eval mode
                net.train()
                valid_auc = f' Rec AUC: {auc_rec:.3%} | AD AUC: {auc_ad:.3%} | R {self.R:.3f} |'

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | Train Time: {epoch_train_time:.3f} [s] '
                f'| Train Loss: {epoch_loss / n_batch:.6f} |' + valid_auc)

            # append the epoch loss to results list
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # get_last_lr() reports the rate actually in use; get_lr() called
                # outside of step() re-applies gamma at a milestone
                logger.info(
                    f'>>> LR Scheduler : new learning rate {scheduler.get_last_lr()[0]:g}'
                )

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'>>> Training of Joint DeepSVDD and AutoEncoder Time: {self.train_time:.3f} [s]'
        )
        logger.info('>>> Finished Joint DeepSVDD and AutoEncoder Training.\n')

        return net
    def pretrain(self, dataset, net):
        """
        Pretrain the autoencoder part of the joint DeepSVDD with the masked
        reconstruction loss only.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the training dataset. Each
            |           item must be a 5-tuple whose 1st, 3rd and 4th elements
            |           are an image, a mask and a semi-supervized label
            |           (-1 marks known abnormal samples).
            |---- net (nn.Module) The DeepSVDD to pretrain. The network should be
            |           an autoencoder whose forward pass returns a pair whose
            |           first element is the reconstruction.
        OUTPUT
            |---- net (nn.Module) The pretrained joint DeepSVDD. Also sets
            |           self.pretrain_loss (list of [epoch, mean loss]) and
            |           self.pretrain_time.
        """
        logger = logging.getLogger()

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs_dataloader)
        net = net.to(self.device)
        # only the reconstruction is used during pretraining
        net.return_svdd_embed = False

        reconstruction_loss = MaskedMSELoss()
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        logger.info('>>> Start Pretraining the Autoencoder.')
        t_start = time.time()
        history = []
        total_batches = len(train_loader)

        net.train()
        for epoch in range(self.n_epoch_pretrain):
            running_loss, seen_batches = 0.0, 0
            t_epoch = time.time()

            for batch_i, (images, _, masks, semi_labels, _) in enumerate(train_loader):
                images = images.to(self.device).float()
                masks = masks.to(self.device)
                semi_labels = semi_labels.to(self.device)
                images.requires_grad = True

                # keep only the object region of the image
                images = images * masks

                optimizer.zero_grad()
                rec, _ = net(images)
                # known abnormal samples (semi-label -1) use their input as
                # "reconstruction", so they contribute a zero loss
                is_not_abnormal = semi_labels.view(-1, 1, 1, 1).expand(*images.shape) != -1
                rec = torch.where(is_not_abnormal, rec, images)
                loss = reconstruction_loss(rec, images, masks)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                seen_batches += 1

                if self.print_batch_progress:
                    print_progessbar(batch_i, total_batches, Name='\t\tBatch', Size=20)

            elapsed = time.time() - t_epoch
            mean_loss = running_loss / seen_batches
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epoch_pretrain:03} | Pretrain Time: {elapsed:.3f} [s] '
                f'| Pretrain Loss: {mean_loss:.6f} |')
            history.append([epoch + 1, mean_loss])

        self.pretrain_loss = history
        self.pretrain_time = time.time() - t_start
        logger.info(
            f'>>> Pretraining of AutoEncoder Time: {self.pretrain_time:.3f} [s]'
        )
        logger.info('>>> Finished of AutoEncoder Pretraining.\n')

        return net
Exemplo n.º 26
0
    def train(self, dataset, net):
        """
        Train the joint DMSVDD (multi-sphere DeepSVDD + autoencoder) on the
        provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. Each item must be a 5-tuple whose
            |           1st, 3rd and 4th elements are an image, a mask and a
            |           semi-supervized label (-1 marks known abnormal samples).
            |---- net (nn.Module) The DMSVDD to train. The network should be an
            |           autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- net (nn.Module) The trained joint DMSVDD.
        ----------
        NOTE
            Once epoch >= n_epoch_warm_up, the spheres are updated at the end of
            every epoch. Each training sample with semi-label != -1 is assigned
            to its closest center.
            - With soft_boundary, each radius is set to the (1 - nu) quantile of
              the distances of its samples. Spheres holding fewer than
              nu * (largest sphere count) samples, or with a zero radius, are
              dropped.
            - Without soft_boundary, centers that received no sample are dropped.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \
                                                   shuffle=True, num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True  # enable the network to provide the SVDD embedding

        # initialize the hyperspheres centers if not already available
        if self.c is None:
            logger.info('>>> Initializing the hyperspheres centers.')
            self.initialize_centers(train_loader, net)
            logger.info(f'>>> {self.n_sphere_init} centers succesfully initialized.')

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss()
        criterion_ad = DMSVDDLoss(self.nu, eps=self.eps, soft_boundary=self.soft_boundary)

        # compute the scale weight so that the rec and svdd losses are scaled and comparable
        self.initialize_loss_scale_weight(train_loader, net, criterion_rec, criterion_ad)

        # define optimizer
        optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # define scheduler : 90% reduction at each milestone
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the Joint DMSVDD and Autoencoder.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        # set network in train mode
        net.train()
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            # update network
            for b, data in enumerate(train_loader):
                x, _, mask, semi_label, _ = data
                # put inputs to device (gradients are only needed w.r.t. the
                # network parameters, not the inputs)
                x = x.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)

                # mask the input (keep only the object)
                x = x * mask

                # zeros the gradient
                optimizer.zero_grad()

                # Update network parameters by backpropagation on the two criterion
                rec, embed = net(x)
                # ignore reconstruction for known abnormal samples (semi-label -1):
                # their "reconstruction" is the input itself, so their loss is 0
                rec = torch.where(semi_label.view(-1, 1, 1, 1).expand(*x.shape) != -1, rec, x)
                loss_rec = criterion_rec(rec, x, mask)
                loss_rec = self.scale_rec * self.criterion_weight[0] * loss_rec
                # SVDD embedding loss
                loss_ad = criterion_ad(embed, self.c, self.R)
                loss_ad = self.scale_em * self.criterion_weight[1] * loss_ad
                loss = loss_rec + loss_ad

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tWeight-update Batch', Size=20)

            with torch.no_grad():
                # update radius R and prune the centers
                if epoch >= self.n_epoch_warm_up:
                    n_center = self.c.shape[0]
                    idx_chunks, dist_chunks = [], []

                    # NOTE(review): the network is still in train mode here
                    # (e.g. BatchNorm/dropout active); confirm this is intended.
                    for b, data in enumerate(train_loader):
                        x, _, mask, semi_label, _ = data
                        x = x.to(self.device).float()
                        mask = mask.to(self.device)
                        semi_label = semi_label.to(self.device)

                        # mask the input and drop known abnormal samples
                        x = (x * mask)[semi_label != -1]
                        _, embed = net(x)
                        # distance to, and index of, the closest center of each sample
                        min_dist, idx = torch.min(
                            torch.norm(self.c.unsqueeze(0) - embed.unsqueeze(1), p=2, dim=2), dim=1)
                        idx_chunks.append(idx.cpu().numpy())
                        dist_chunks.append(min_dist.cpu().numpy())

                        if self.print_batch_progress:
                            print_progessbar(b, n_batch_tot, Name='\t\tRadius-update Batch', Size=20)

                    # gather the assignments of the whole epoch (vectorized
                    # instead of a per-sample Python loop over tensors)
                    all_idx = np.concatenate(idx_chunks) if idx_chunks else np.zeros(0, dtype=np.int64)
                    all_dist = np.concatenate(dist_chunks) if dist_chunks else np.zeros(0, dtype=np.float32)
                    counts = np.bincount(all_idx, minlength=n_center)
                    n_k = torch.tensor(counts, dtype=torch.float, device=self.device)

                    if self.soft_boundary:
                        # radius of each sphere: (1 - nu) quantile of the distances
                        # of its assigned samples (0 for spheres without sample)
                        radius = [np.quantile(all_dist[all_idx == k], 1 - self.nu) if counts[k] > 0 else 0.0
                                  for k in range(n_center)]
                        # spheres with fewer than nu * (largest count) samples get R = 0
                        self.R = torch.where(n_k < self.nu * torch.max(n_k),
                                             torch.Tensor([0.0]).to(self.device),
                                             torch.Tensor(radius).to(self.device))

                        # keep only centers and radius where R > 0
                        self.c = self.c[self.R > 0.0]
                        self.R = self.R[self.R > 0.0]
                    else:
                        # drop the centers that were not assigned any sample
                        # (the previous mask n_k == 0 kept ONLY those centers)
                        self.c = self.c[n_k > 0]

            # epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                        f'| Train Time: {epoch_train_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} '
                        f'| N spheres: {self.c.shape[0]:03} |')

            # append the epoch loss to results list
            epoch_loss_list.append([epoch+1, epoch_loss/n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # get_last_lr() reports the rate actually in use; get_lr() called
                # outside of step() re-applies gamma at a milestone
                logger.info(f'>>> LR Scheduler : new learning rate {scheduler.get_last_lr()[0]:g}')

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(f'>>> Training of Joint DMSVDD and AutoEncoder Time: {self.train_time:.3f} [s]')
        logger.info('>>> Finished Joint DMSVDD and AutoEncoder Training.\n')

        return net
    def validate(self, dataset, save_path=None, epoch=0):
        """
        Validate the generator inpainting capabilities on a small sample of data.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) dataset returning a small sample of fixed pairs (image, mask)
            |               on which to validate the GAN over training. It must return an image tensor of dimension
            |               [C, H, W] as first element and a mask tensor of dimension [1, H, W] as second element.
            |               Any additional returned element (e.g. an index) is ignored.
            |---- save_path (str) path to directory where to save the inpaint results of the validation data as
            |               .png. Each image is saved as save_path/valid_imY_epXXX.png (fine inpainting) and
            |               save_path/valid_imY_coarse_epXXX.png (coarse inpainting), where Y is the 1-based image
            |               index and XXX is the epoch number (not zero-padded). Nothing is saved if None/empty.
            |---- epoch (int) the current epoch number (only used in the saved file names).
        OUTPUT
            |---- l1_loss (float) the Discounted L1Loss averaged over the validation batches, or NaN if the
            |               dataset yields no batch.
        """
        # Remember the generator mode to restore it at the end, so that a
        # validation called in the middle of training does not leave the
        # generator stuck in evaluation mode.
        was_training = self.generator.training
        try:
            with torch.no_grad():
                # make loader (np.random.seed() without argument reseeds numpy
                # from fresh entropy in each worker)
                valid_loader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                    shuffle=False,
                    worker_init_fn=lambda _: np.random.seed())
                n_batch = len(valid_loader)
                if n_batch == 0:
                    # nothing to validate: avoid a ZeroDivisionError below
                    return float('nan')
                l1_loss_fn = DiscountedL1(gamma=self.gammaL1,
                                          reduction='mean',
                                          device=self.device)

                self.generator.eval()
                # validate data by batch
                l1_loss = 0.0
                idx = 1  # 1-based running image counter used in the file names
                for b, data in enumerate(valid_loader):
                    # only the image and the mask are used
                    im_v = data[0].to(self.device).float()
                    mask_v = data[1].to(self.device).float()

                    # inpaint
                    im_inpaint, coarse = self.generator(im_v, mask_v)
                    # keep the original pixels where mask == 0 and the
                    # generated ones where mask == 1
                    im_inpaint = im_v * (1 - mask_v) + im_inpaint * mask_v
                    coarse = im_v * (1 - mask_v) + coarse * mask_v
                    # compute L1 loss
                    l1_loss += l1_loss_fn(im_inpaint, im_v, mask_v).item()
                    # save results
                    if save_path:
                        for i in range(im_inpaint.shape[0]):
                            arr = im_inpaint[i].permute(1, 2, 0).squeeze().cpu().numpy()
                            io.imsave(
                                os.path.join(save_path,
                                             f'valid_im{idx}_ep{epoch}.png'),
                                img_as_ubyte(arr))

                            arr = coarse[i].permute(1, 2, 0).squeeze().cpu().numpy()
                            io.imsave(
                                os.path.join(
                                    save_path,
                                    f'valid_im{idx}_coarse_ep{epoch}.png'),
                                img_as_ubyte(arr))
                            idx += 1

                    print_progessbar(b,
                                     n_batch,
                                     Name='Valid Batch',
                                     Size=40,
                                     erase=True)

            return l1_loss / n_batch
        finally:
            # restore the generator mode it had before validation
            self.generator.train(was_training)
    def train(self, dataset, net, valid_dataset=None):
        """
        Train the DeepSAD network on the provided dataset.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image and
            |           semi-supervized labels (5-element samples, of which the
            |           1st (input) and 4th (semi-supervised label) are used).
            |---- net (nn.Module) The DeepSAD to train. Its forward pass must
            |           return a pair whose second element is the embedding.
            |---- valid_dataset (torch.utils.data.Dataset) the dataset on which
            |           to validate the network at each epoch. Not validated if
            |           None.
        OUTPUT
            |---- net (nn.Module) The trained DeepSAD.
        """
        logger = logging.getLogger()

        # make the train dataloader
        train_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.n_job_dataloader)

        # put net to device
        net = net.to(self.device)

        # initialize hypersphere center (only if not already set)
        if self.c is None:
            logger.info(' Initializing the hypersphere center.')
            self.c = self.initialize_hypersphere_center(train_loader, net)
            logger.info(' Center succesfully initialized.')

        # define loss criterion
        loss_fn = DeepSADLoss(self.c, self.eta, eps=self.eps)

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define scheduler (lr divided by 10 at each milestone)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('Start Training the DeepSAD.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch = len(train_loader)

        for epoch in range(self.n_epoch):
            # train mode is set every epoch because the validation below may
            # switch the network to evaluation mode
            net.train()
            epoch_loss = 0.0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                # get input and semi-supervized labels
                inputs, _, _, semi_label, _ = data
                # put them to device (no gradient w.r.t. the input is needed)
                inputs = inputs.to(self.device).float()
                semi_label = semi_label.to(self.device)

                # zero the network's gradients
                optimizer.zero_grad()
                # optimize by backpropagation
                _, embed = net(inputs)
                loss = loss_fn(embed, semi_label)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                if self.print_batch_progress:
                    print_progessbar(b,
                                     n_batch,
                                     Name='\t\t Train Batch',
                                     Size=40,
                                     erase=True)

            # validate if required
            valid_auc = ''
            if valid_dataset is not None:
                auc = self.evaluate(net,
                                    valid_dataset,
                                    return_auc=True,
                                    print_to_logger=False,
                                    save_tSNE=False)
                valid_auc = f' Valid AUC {auc:.3%} |'

            # log the epoch statistics
            logger.info(
                f'----| Epoch: {epoch + 1:03}/{self.n_epoch:03} '
                f'| Train Time: {time.time() - epoch_start_time:.3f} [s] '
                f'| Train Loss: {epoch_loss / n_batch:.6f} |' + valid_auc)

            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update scheduler
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # get_last_lr() returns the lr actually in use; get_lr() called
                # after step() would apply gamma a second time at a milestone
                logger.info(
                    f'---- LR Scheduler : new learning rate {scheduler.get_last_lr()[0]:g}'
                )

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(
            f'---- Finished Training DSAD in {self.train_time:.3f} [s]')

        return net
Exemplo n.º 29
0
    def train(self, dataset, net):
        """
        Train the DROCC network on the provided dataset.

        During the first `self.n_epoch_init` epochs the network is trained as a
        plain binary classifier (target 0 = normal, 1 = abnormal). Afterwards,
        adversarial samples are generated from the normal samples of each batch
        with `self.adversarial_search` and added to the loss as abnormal samples,
        weighted by `self.mu`.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is trained. It must return an image, a mask and
            |           semi-supervised labels (5-element samples, of which the
            |           1st, 3rd and 4th are used; semi-label -1 = abnormal).
            |---- net (nn.Module) The DROCC to train. The network should return
            |           the logit of the passed sample.
        OUTPUT
            |---- net (nn.Module) The trained DROCC.
        """
        logger = logging.getLogger()

        # make dataloader
        train_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.n_jobs_dataloader)
        # put net to device
        net = net.to(self.device)

        # loss function
        criterion = nn.BCEWithLogitsLoss()

        # define optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # define scheduler (lr divided by 10 at each milestone)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestone, gamma=0.1)

        # Start training
        logger.info('>>> Start Training the DROCC.')
        start_time = time.time()
        epoch_loss_list = []
        n_batch_tot = len(train_loader)
        for epoch in range(self.n_epoch):
            # set network in train mode at the start of every epoch
            net.train()
            epoch_loss = 0.0
            n_batch = 0
            epoch_start_time = time.time()

            for b, data in enumerate(train_loader):
                inputs, _, mask, semi_label, _ = data
                inputs = inputs.to(self.device).float()
                mask = mask.to(self.device)
                semi_label = semi_label.to(self.device)
                # float target: 0 = normal (semi_label != -1), 1 = abnormal
                target = (semi_label == -1).float()
                # mask the input
                inputs = inputs * mask

                if epoch < self.n_epoch_init:
                    # initial steps without adversarial samples
                    logit = net(inputs).squeeze(dim=1)
                    loss = criterion(logit, target)
                else:
                    # get adversarial samples from the normal inputs only. A batch
                    # without normal sample has no adversarial sample: BCE on an
                    # empty tensor would give a NaN loss and corrupt the weights.
                    normal_input = inputs[target == 0]
                    adv_input = None
                    if normal_input.shape[0] > 0:
                        adv_input = self.adversarial_search(normal_input, net)

                    # loss of samples
                    logit = net(inputs).squeeze(dim=1)
                    loss = criterion(logit, target)
                    if adv_input is not None:
                        # loss of adversarial samples (all labelled abnormal)
                        logit_adv = net(adv_input).squeeze(dim=1)
                        loss_adv = criterion(
                            logit_adv,
                            torch.ones(adv_input.shape[0], device=self.device))
                        # weighted sum of normal and adversarial loss
                        loss = loss + self.mu * loss_adv

                # Gradient step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

            # print epoch statistic
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epoch:03} | '
                        f'Train Time: {epoch_train_time:.3f} [s] '
                        f'| Train Loss: {epoch_loss / n_batch:.6f} |')

            # append the epoch loss to results list
            epoch_loss_list.append([epoch + 1, epoch_loss / n_batch])

            # update the learning rate if the milestone is reached
            scheduler.step()
            if epoch + 1 in self.lr_milestone:
                # get_last_lr() returns the lr actually in use; get_lr() called
                # after step() would apply gamma a second time at a milestone
                logger.info(
                    f'>>> LR Scheduler : new learning rate {scheduler.get_last_lr()[0]:g}'
                )

        # End training
        self.train_loss = epoch_loss_list
        self.train_time = time.time() - start_time
        logger.info(f'>>> Training Time of DROCC: {self.train_time:.3f} [s]')
        logger.info('>>> Finished DROCC Training.\n')

        return net
Exemplo n.º 30
0
    def validate(self, dataset, net):
        """
        Validate the joint DMSVDD network on the provided dataset and find the
        best threshold on the score to maximize the f1-score.
        ----------
        INPUT
            |---- dataset (torch.utils.data.Dataset) the dataset on which the
            |           network is validated. It must return an image, a label,
            |           a mask, semi-supervized labels and the sample index.
            |---- net (nn.Module) The DMSVDD to validate. The network should be
            |           an autoencoder for which the forward pass returns both the
            |           reconstruction and the embedding of the input.
        OUTPUT
            |---- None. The results are stored as attributes:
            |           valid_time, valid_scores_rec / valid_scores_ad (lists of
            |           (sample index, label, score)), valid_auc_rec / valid_auc_ad,
            |           scores_threhold_rec / scores_threhold_ad and
            |           valid_f1_rec / valid_f1_ad.
        """
        logger = logging.getLogger()

        # make test dataloader using image and mask
        valid_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.n_jobs_dataloader)

        # put net to device
        net = net.to(self.device)
        net.return_svdd_embed = True

        # define the two criterion for Anomaly detection and reconstruction
        criterion_rec = MaskedMSELoss(reduction='none')
        criterion_ad = DMSVDDLoss(self.nu, eps=self.eps, soft_boundary=self.soft_boundary)

        # Testing
        logger.info('>>> Start Validating of the joint DMSVDD and AutoEncoder.')
        epoch_loss = 0.0
        n_batch = 0
        n_batch_tot = len(valid_loader)
        start_time = time.time()
        idx_label_score_rec = []
        idx_label_score_ad = []
        # put network in evaluation mode
        net.eval()
        with torch.no_grad():
            for b, data in enumerate(valid_loader):
                inputs, label, mask, _, idx = data
                # put data to device
                inputs, label = inputs.to(self.device).float(), label.to(self.device)
                mask = mask.to(self.device)
                idx = idx.to(self.device)

                # mask the input
                inputs = inputs * mask

                # compute loss
                rec, embed = net(inputs)
                loss_rec = criterion_rec(rec, inputs, mask)
                loss_ad = criterion_ad(embed, self.c, self.R)
                # reconstruction score: mean over all non-batch dimensions
                rec_score = torch.mean(loss_rec, dim=tuple(range(1, rec.dim())))

                # squared distance to the closest center and that center's index.
                # A distinct name is required: `idx` holds the sample indices
                # recorded with the scores below.
                dist, center_idx = torch.min(
                    torch.sum((self.c.unsqueeze(0) - embed.unsqueeze(1)) ** 2, dim=2),
                    dim=1)
                if self.soft_boundary:
                    # dist - R_k ** 2 --> negative = normal ; positive = abnormal
                    ad_score = dist - self.R[center_idx] ** 2
                else:
                    ad_score = dist

                # compute overall loss
                mean_loss_rec = torch.sum(loss_rec) / torch.sum(mask)
                loss = self.scale_rec * self.criterion_weight[0] * mean_loss_rec
                loss += self.scale_em * self.criterion_weight[1] * loss_ad

                # append scores and label
                idx_list = idx.cpu().data.numpy().tolist()
                label_list = label.cpu().data.numpy().tolist()
                idx_label_score_rec += list(zip(idx_list, label_list,
                                                rec_score.cpu().data.numpy().tolist()))
                idx_label_score_ad += list(zip(idx_list, label_list,
                                               ad_score.cpu().data.numpy().tolist()))

                epoch_loss += loss.item()
                n_batch += 1

                if self.print_batch_progress:
                    print_progessbar(b, n_batch_tot, Name='\t\tBatch', Size=20)

        self.valid_time = time.time() - start_time
        self.valid_scores_rec = idx_label_score_rec
        _, label, rec_score = zip(*idx_label_score_rec)
        label, rec_score = np.array(label), np.array(rec_score)
        self.valid_auc_rec = roc_auc_score(label, rec_score)
        self.scores_threhold_rec, self.valid_f1_rec = get_best_threshold(rec_score, label, metric=f1_score)

        self.valid_scores_ad = idx_label_score_ad
        _, label, ad_score = zip(*idx_label_score_ad)
        label, ad_score = np.array(label), np.array(ad_score)
        self.valid_auc_ad = roc_auc_score(label, ad_score)
        self.scores_threhold_ad, self.valid_f1_ad = get_best_threshold(ad_score, label, metric=f1_score)

        # add info to logger
        logger.info(f'>>> Validation Time: {self.valid_time:.3f} [s]')
        logger.info(f'>>> Validation Loss: {epoch_loss / n_batch:.6f}')
        logger.info(f'>>> Validation reconstruction AUC: {self.valid_auc_rec:.3%}')
        logger.info(f'>>> Best Threshold for the reconstruction score maximizing F1-score: {self.scores_threhold_rec:.3f}')
        logger.info(f'>>> Best F1-score on reconstruction score: {self.valid_f1_rec:.3%}')
        logger.info(f'>>> Validation DMSVDD AUC: {self.valid_auc_ad:.3%}')
        logger.info(f'>>> Best Threshold for the DMSVDD score maximizing F1-score: {self.scores_threhold_ad:.3f}')
        logger.info(f'>>> Best F1-score on DMSVDD score: {self.valid_f1_ad:.3%}')
        logger.info('>>> Finished validating the Joint DMSVDD and AutoEncoder.\n')