Code example #1
    def pretrain(self, deepSVDD, cfg, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting pretraining...')
        start_time = time.time()
        net.train()
        best_score = 0
        for epoch in range(self.pre_training_epochs):

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()

            for data in tqdm(train_loader):
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()
                # Update network parameters via backpropagation: forward + backward + optimize
                _, rec_images = net(inputs)
                loss = torch.mean(
                    torch.sum(torch.abs(rec_images - inputs),
                              dim=tuple(range(1, rec_images.dim()))))
                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.pre_training_epochs, epoch_train_time,
                loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
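
The pretraining loss above sums the absolute reconstruction error over every non-batch dimension before averaging over the batch. A minimal, self-contained illustration of that reduction pattern (shapes are arbitrary and chosen only for the sketch):

import torch

x = torch.randn(4, 3, 32, 32)        # a batch of 4 fake images
x_rec = torch.randn(4, 3, 32, 32)    # stand-in for the reconstructed images

# sum |error| over channels/height/width, keeping the batch dimension
per_sample = torch.sum(torch.abs(x_rec - x), dim=tuple(range(1, x.dim())))
loss = torch.mean(per_sample)        # scalar, as in the pretraining loop above
print(per_sample.shape)              # torch.Size([4])
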
Code example #2
    def train_one_step(self, net: BaseNet, epoch: int):
        logger = logging.getLogger()

        # Set device for network
        # net = net.to(self.device)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()

        if (True):
            self.scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(self.scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in self.train_loader:
                inputs, _, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                self.optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c)**2, dim=1)
                if self.objective == 'soft-boundary':
                    scores = dist - self.R**2
                    loss = self.R**2 + (1 / self.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    loss = torch.mean(dist)
                loss.backward()
                self.optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                epoch_loss / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
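
The soft-boundary branch above relies on get_radius, which is not shown in this listing. In the reference Deep SVDD implementation the radius is the (1 - nu)-quantile of the mini-batch distances; a sketch along those lines (treat it as an assumption about the helper, not a verbatim copy of this project's code):

import numpy as np
import torch


def get_radius(dist: torch.Tensor, nu: float):
    """Optimally solve for radius R via the (1 - nu)-quantile of the distances."""
    return np.quantile(np.sqrt(dist.clone().data.cpu().numpy()), 1 - nu)
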
Code example #3
File: svm_trainer.py Project: Csraf/dldm
    def train(self, dataset: BaseADDataset, svm_net: BaseNet):
        """ 训练 svm 模型 """
        logger = logging.getLogger()

        # Set device for networks
        svm_net = svm_net.to(self.device)

        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)
        optimizer = optim.SGD(svm_net.parameters(),
                              lr=self.lr,
                              momentum=self.momentum)
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=self.step_size,
                                              gamma=self.gamma)

        # Training
        logger.info('Starting train svm_trainer ...')
        start_time = time.time()
        svm_net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, labels, _ = data
                inputs = inputs.to(self.device)

                # Zero the networks parameter gradients
                optimizer.zero_grad()

                # Update networks parameters via back propagation: forward + backward + optimize
                outputs = svm_net(inputs)

                # get loss
                loss = self.hinge_loss(outputs, labels)
                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

        pretrain_time = time.time() - start_time
        logger.info('svm_trainer train time: %.3f' % pretrain_time)
        logger.info('Finished train svm_trainer.')

        return svm_net
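
The objective above is delegated to self.hinge_loss, which is not part of the excerpt. A common hinge-loss formulation that fits this call site, assuming labels encoded as -1/+1 and raw scores from svm_net, is sketched below:

import torch


def hinge_loss(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean hinge loss max(0, 1 - y * f(x)); assumes labels are -1/+1."""
    labels = labels.float().to(outputs.device)
    return torch.mean(torch.clamp(1.0 - labels * outputs.squeeze(), min=0.0))
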
Code example #4
    def train(self, dataset: BaseADDataset, ae_net: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        train_loader, _, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Set loss
        criterion = nn.MSELoss(reduction='none')

        # Set device
        #ae_net = ae_net.to(self.device)
        #criterion = criterion.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(ae_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting pretraining...')
        start_time = time.time()
        ae_net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                rec = ae_net(inputs)
                rec_loss = criterion(rec, inputs)
                loss = torch.mean(rec_loss)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                        f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        logger.info('Pretraining Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished pretraining.')

        return ae_net
Code example #5
    def train(self, dataset: BaseADDataset, ae_net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        ae_net = ae_net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(ae_net.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting pretraining...')
        start_time = time.time()
        ae_net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = ae_net(inputs)
                scores = torch.sum((outputs.float() - inputs.float()) ** 2, dim=tuple(range(1, outputs.dim())))
                loss = torch.mean(scores)
                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'
                        .format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches))

        pretrain_time = time.time() - start_time
        logger.info('Pretraining time: %.3f' % pretrain_time)
        logger.info('Finished pretraining.')

        return ae_net
Code example #6
ファイル: lstm_trainer.py プロジェクト: Csraf/dldm
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for networks
        net = net.to(self.device)

        train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        optimizer = optim.RMSprop(net.parameters(), lr=self.lr, weight_decay=self.weight_decay, eps=self.epsilon,
                                  momentum=self.momentum)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting train lstm_autoencoder ...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the networks parameter gradients
                optimizer.zero_grad()

                # Update networks parameters via back propagation: forward + backward + optimize
                _, outputs = net(inputs.view(-1, 1, self.n_features))
                scores = torch.sum((outputs - inputs) ** 2, dim=tuple(range(1, outputs.dim())))
                loss = torch.mean(scores)
                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'
                        .format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('lstm_autoencoder train time: %.3f' % self.train_time)
        logger.info('Finished train lstm_autoencoder.')

        return net
Code example #7
    def train_one_step(self, net: BaseNet, epoch: int):

        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()

        epoch_loss = 0.0
        n_batches = 0
        epoch_start_time = time.time()

        for data in self.train_loader:
            inputs, targets, _, _ = data
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # Zero the network parameter gradients
            self.optimizer.zero_grad()

            # Update network parameters via backpropagation: forward + backward + optimize
            outputs = net(inputs)
            targets = targets.type_as(outputs)
            loss = self.criterion(outputs, targets.unsqueeze(1))
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1
            self.scheduler.step()

        if epoch in self.lr_milestones:
            logger.info('  LR scheduler: new learning rate is %g' %
                        float(self.scheduler.get_lr()[0]))

        # log epoch statistics
        epoch_train_time = time.time() - epoch_start_time
        logger.info(
            f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
            f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        return {'train_loss': epoch_loss / n_batches}
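
The step above passes targets.unsqueeze(1) to self.criterion, which matches a single-logit binary classifier: with nn.BCEWithLogitsLoss the targets must be floats of the same (N, 1) shape as the outputs. The criterion choice is an assumption here; a minimal illustration of that shape handling:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
logits = torch.randn(8, 1)             # one logit per sample
targets = torch.randint(0, 2, (8,))    # integer 0/1 labels from the loader
loss = criterion(logits, targets.type_as(logits).unsqueeze(1))
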
Code example #8
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, semi_targets, _ = data
                inputs, semi_targets = inputs.to(self.device), semi_targets.to(
                    self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c)**2, dim=1)
                losses = torch.where(
                    semi_targets == 0, dist,
                    self.eta * ((dist + self.eps)**semi_targets.float()))
                loss = torch.mean(losses)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        return net
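
This trainer (and the Deep SVDD variants below) calls self.init_center_c before training, which is not included in the excerpt. In the reference Deep SVDD / Deep SAD code the center is the mean of one forward pass over the training data, with coordinates too close to zero pushed away to avoid a trivially collapsed solution. A sketch of that routine (attribute names such as net.rep_dim are assumptions):

    def init_center_c(self, train_loader, net, eps: float = 0.1):
        """Initialize hypersphere center c as the mean of an initial forward pass."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        with torch.no_grad():
            for data in train_loader:
                inputs = data[0].to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)
        c /= n_samples

        # If any coordinate of c is too close to 0, push it away from 0:
        # an exact zero could be matched by a collapsed network.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps
        return c
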
Code example #9
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                # dist = torch.sum((outputs - self.c) ** 2, dim=1)

                ### NEW - get closest cluster center, take dist, sum/mean for loss
                centers = torch.transpose(self.c, 0, 1)
                dist = torch.zeros(outputs.shape[0], device=self.device)
                for i in range(outputs.shape[0]):
                    # Sum dists from each data point to its corresponding cluster
                    dist[i] = torch.sum((centers - outputs[i])**2, dim=1).min()
                #import pdb; pdb.set_trace()
                ###
                if self.objective == 'soft-boundary':
                    scores = dist - self.R**2
                    loss = self.R**2 + (1 / self.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    loss = torch.mean(dist)
                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
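
The per-sample Python loop above computes the squared distance to the closest cluster center one row at a time. An equivalent vectorized form (a suggested alternative, assuming outputs of shape (N, d) and centers of shape (K, d)):

import torch

outputs = torch.randn(16, 32)                  # (N, d) network outputs
centers = torch.randn(5, 32)                   # (K, d) cluster centers

d2 = torch.cdist(outputs, centers, p=2) ** 2   # (N, K) squared distances
dist = d2.min(dim=1).values                    # (N,) distance to the closest center
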
Code example #10
    def train(self, dataset: BaseADDataset, oe_dataset: BaseADDataset,
              net: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        if oe_dataset is not None:
            num_workers = int(self.n_jobs_dataloader / 2)
        else:
            num_workers = self.n_jobs_dataloader

        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=num_workers)
        if oe_dataset is not None:
            if oe_dataset.shuffle:
                if len(dataset.train_set) > len(oe_dataset.train_set):
                    oe_sampler = RandomSampler(oe_dataset.train_set,
                                               replacement=True,
                                               num_samples=len(
                                                   dataset.train_set))
                    oe_loader = DataLoader(dataset=oe_dataset.train_set,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           sampler=oe_sampler,
                                           num_workers=num_workers,
                                           drop_last=True)
                else:
                    oe_loader = DataLoader(dataset=oe_dataset.train_set,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=num_workers,
                                           drop_last=True)

            else:
                oe_loader = DataLoader(dataset=oe_dataset.train_set,
                                       batch_size=self.batch_size,
                                       shuffle=False,
                                       num_workers=num_workers,
                                       drop_last=True)
            dataset_loader = zip(train_loader, oe_loader)
        else:
            dataset_loader = train_loader

        # Set loss
        if self.objective in ['bce', 'focal']:
            if self.objective == 'bce':
                criterion = nn.BCEWithLogitsLoss()
            if self.objective == 'focal':
                criterion = FocalLoss(gamma=self.focal_gamma)
            criterion = criterion.to(self.device)

        # Set device
        net = net.to(self.device)

        # Set optimizer
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        net.train()
        start_time = time.time()

        for epoch in range(self.n_epochs + 1):
            epoch_loss = 0.0
            n_batches = 0
            idx_label_score = []
            epoch_start_time = time.time()

            # start at random point for the outlier exposure dataset in each epoch
            if (oe_dataset is not None) and (epoch < self.n_epochs):
                oe_loader.dataset.offset = np.random.randint(
                    len(oe_loader.dataset))
                if oe_loader.dataset.shuffle_idxs:
                    random.shuffle(oe_loader.dataset.idxs)
                dataset_loader = zip(train_loader, oe_loader)

            # only load samples from the original training set in a last epoch for saving train scores
            if epoch == self.n_epochs:
                dataset_loader = train_loader
                net.eval()

            for data in dataset_loader:
                if (oe_dataset is not None) and (epoch < self.n_epochs):
                    inputs = torch.cat((data[0][0], data[1][0]), 0)
                    labels = torch.cat((data[0][1], data[1][1]), 0)
                    semi_targets = torch.cat((data[0][2], data[1][2]), 0)
                    idx = torch.cat((data[0][3], data[1][3]), 0)
                else:
                    inputs, labels, semi_targets, idx = data

                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                semi_targets = semi_targets.to(self.device)
                idx = idx.to(self.device)

                # Zero the network parameter gradients
                if epoch < self.n_epochs:
                    optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)

                if self.objective == 'hsc':
                    if self.hsc_norm == 'l1':
                        dists = torch.norm(outputs, p=1, dim=1)
                    if self.hsc_norm == 'l2':
                        dists = torch.norm(outputs, p=2, dim=1)
                    if self.hsc_norm == 'l2_squared':
                        dists = torch.norm(outputs, p=2, dim=1)**2
                    if self.hsc_norm == 'l2_squared_linear':
                        dists = torch.sqrt(
                            torch.norm(outputs, p=2, dim=1)**2 + 1) - 1

                    scores = 1 - torch.exp(-dists)
                    losses = torch.where(semi_targets == 0, dists,
                                         -torch.log(scores + self.eps))
                    loss = torch.mean(losses)

                if self.objective == 'deepSAD':
                    dists = torch.norm(outputs, p=2, dim=1)**2
                    scores = dists
                    losses = torch.where(
                        semi_targets == 0, dists,
                        ((dists + self.eps)**semi_targets.float()))
                    loss = torch.mean(losses)

                if self.objective in ['bce', 'focal']:
                    targets = torch.zeros(inputs.size(0))
                    targets[semi_targets == -1] = 1
                    targets = targets.view(-1, 1).to(self.device)

                    scores = torch.sigmoid(outputs)
                    loss = criterion(outputs, targets)

                if epoch < self.n_epochs:
                    loss.backward()
                    optimizer.step()

                # save train scores in last epoch
                if epoch == self.n_epochs:
                    idx_label_score += list(
                        zip(idx.cpu().data.numpy().tolist(),
                            labels.cpu().data.numpy().tolist(),
                            scores.flatten().cpu().data.numpy().tolist()))

                epoch_loss += loss.item()
                n_batches += 1

            # Take learning rate scheduler step
            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_last_lr()[0]))

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        self.train_scores = idx_label_score

        # Log results
        logger.info('Train Time: {:.3f}s'.format(self.train_time))
        logger.info('Train Loss: {:.6f}'.format(epoch_loss / n_batches))
        logger.info('Finished training.')

        return net
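
The 'focal' objective above instantiates a FocalLoss class that is not shown. A common binary focal loss on raw logits, with gamma as the focusing parameter, is sketched here as an assumption about that class rather than the project's exact implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Binary focal loss on logits: mean of (1 - p_t)^gamma * BCE."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = torch.exp(-bce)  # probability the model assigns to the true class
        return torch.mean((1.0 - p_t) ** self.gamma * bce)
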
Code example #11
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        self.train_loader = train_loader

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()

            for data in train_loader:
                inputs, _, semi_targets, _ = data
                inputs, semi_targets = inputs.to(self.device), semi_targets.to(
                    self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)

                positive, unlabeled = semi_targets, 1 - semi_targets
                n_positive, n_unlabeled = max([1., torch.sum(positive)]), max(
                    [1., torch.sum(unlabeled)])

                gp = torch.t(torch.log(1 + torch.exp(-outputs)))
                gu = torch.t(torch.log(1 + torch.exp(outputs)))

                loss_positive = self.pi * torch.sum(gp * positive) / n_positive
                loss_negative = torch.sum(
                    gu * unlabeled) / n_unlabeled - self.pi * torch.sum(
                        gu * positive) / n_positive

                loss = loss_positive + loss_negative
                """

                func = torch.t(torch.sigmoid(-outputs))

                loss_positive = -self.pi*torch.sum(func*positive)/n_positive
                loss_negative = torch.sum(1/(1-func*unlabeled))/n_unlabeled - self.pi*torch.sum(1/(1-func*positive))/n_positive

                loss = (loss_positive + loss_negative)**2

                """
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        return net
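
The positive-unlabeled risk above evaluates log(1 + exp(±outputs)) directly, which can overflow for large scores. A numerically safer, mathematically equivalent rewrite of the same risk (a suggested variant, not part of the original code) uses torch.nn.functional.softplus:

import torch
import torch.nn.functional as F


def pu_risk(outputs: torch.Tensor, positive: torch.Tensor, unlabeled: torch.Tensor, pi: float):
    """Unbiased PU risk with the logistic loss; positive/unlabeled are 0/1 indicator vectors."""
    outputs = outputs.view(-1)
    positive, unlabeled = positive.float(), unlabeled.float()
    n_positive = torch.clamp(positive.sum(), min=1.0)
    n_unlabeled = torch.clamp(unlabeled.sum(), min=1.0)
    gp = F.softplus(-outputs)  # log(1 + exp(-f(x))): cost of calling x positive
    gu = F.softplus(outputs)   # log(1 + exp(+f(x))): cost of calling x negative
    risk_pos = pi * torch.sum(gp * positive) / n_positive
    risk_neg = torch.sum(gu * unlabeled) / n_unlabeled - pi * torch.sum(gu * positive) / n_positive
    return risk_pos + risk_neg
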
Code example #12
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()

        for epoch in range(self.n_epochs):
            net.train()
            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                dist = torch.sum((outputs - self.c)**2, dim=1)
                if self.objective == 'soft-boundary':
                    scores = dist - self.R**2
                    loss = self.R**2 + (1 / self.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    # normalization dist[] to [0 - 1] by min-max scaler
                    dist_min = torch.min(dist)
                    dist_range = torch.max(dist) - dist_min
                    if dist_range > 0:
                        dist_focal = (dist - dist_min) / dist_range
                    else:
                        dist_focal = torch.zeros(dist.size(), device=self.device)

                    focal_factor = dist_focal**self.focal_parameter
                    loss = torch.mean(dist * focal_factor)

                    # loss = torch.mean(dist)

                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

            self.Loss_list.append(loss_epoch / n_batches)

            # update hypershpere center c
            if epoch == self.n_epochs - self.update_center_epochs:
                logger.info('Updating center c...')
                self.c = self.init_center_c(train_loader, net)
                logger.info('Center c updated.')

            # testing
            f_get_para = open('../log/mnist_test/get_param.txt', 'a')
            f_get_para.write('Epoch: {:.0f}  -----> '.format(epoch + 1))
            f_get_para.close()

            # record test AUC after each 100 epoch
            if (epoch + 1) % 100 == 0:
                f_100_para = open('../log/mnist_test/100_AUC.txt', 'a')
                f_100_para.write('Epoch: %d ------' % (epoch + 1))
                f_100_para.close()

            self.test(dataset, net.cpu(), epoch)

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #13
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _, _ = dataset.loaders(
            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)
        # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.9)
        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        es = EarlyStopping(patience=25)
        if (self.c.dim() == 1):  # naive deep_svdd
            pass
        else:  # multi-center deep_svdd
            cluster_assignments = pickle.load(
                open(os.path.join(self.xp_path, 'cluster_assignments.pkl'),
                     'rb'))
            cluster_numbers = np.unique(list(
                cluster_assignments.values()))  ### TRIPLET
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, idx = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                idx = idx.cpu().numpy()
                ### NEW - get closest cluster center, take dist, sum/mean for loss
                if (self.c.dim() == 1):  # naive deep_svdd
                    centers = self.c
                    dist = torch.sum((outputs - self.c)**2, dim=1)
                else:
                    centers = torch.transpose(self.c, 0, 1)
                    dist = torch.zeros(outputs.shape[0], device=self.device)
                    triplet = torch.zeros(outputs.shape[0],
                                          device=self.device)  ### TRIPLET
                    for i in range(outputs.shape[0]):
                        # Sum dists from each data point to its corresponding cluster
                        # dist[i] = torch.sum((centers - outputs[i]) ** 2, dim=1).min()
                        ### Avoid gradient of min
                        # cluster_idx = torch.sum((centers - outputs[i]) ** 2, dim=1).argmin()
                        # dist[i] = torch.sum((centers[cluster_idx] - outputs[i]) ** 2, dim=0)
                        dist[i] = torch.sum(
                            (centers[cluster_assignments[idx[i]]] -
                             outputs[i])**2,
                            dim=0)
                        # Triplet Loss
                        negative_clusters = np.delete(
                            cluster_numbers, cluster_assignments[idx[i]])
                        neg_example = np.random.choice(negative_clusters)
                        ot = outputs[i].unsqueeze(1)
                        pt = centers[cluster_assignments[idx[i]]].unsqueeze(1)
                        nt = centers[neg_example].unsqueeze(1)
                        triplet[i] = self.triplet_loss(anchor=ot,
                                                       positive=pt,
                                                       negative=nt)
                ###
                if self.objective == 'soft-boundary':
                    scores = dist - self.R**2
                    loss = self.R**2 + (1 / self.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                else:
                    if (self.c.dim() == 1):  # naive deep svdd
                        loss = torch.mean(dist)
                    else:
                        loss = torch.mean(triplet)  ### TRIPLET

                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                loss_epoch += loss.item()
                n_batches += 1

            # # ### DEBUG UMAPs and CENTERS (ADDED EPOCH NUMBER TO LATENT UMAP FUNCTION ###
            # # train_loader, _, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

            # output_data = []
            # label_data = []
            # with torch.no_grad():
            #     for data in train_loader:
            #         # get the inputs of the batch
            #         inputs, labels, _ = data #labels are only for UMAP of hyperspheres
            #         inputs = inputs.to(self.device)
            #         outputs = net(inputs)
            #         output_data.append(outputs)
            #         label_data.append(labels)
            # kmeans_centers = np.load(os.path.join(self.xp_path,'centers.npy'))
            # output_data = torch.cat(output_data)
            # label_data = torch.cat(label_data).numpy()
            # self.latent_UMAP(output_data, label_data, kmeans_centers, pretrain_ae = False, epoch = epoch) ### USE pretrain_ae = False, no repeat
            # ############

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

            # Validation and Early Stopping
            _, val_loader, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)
            if val_loader.dataset is None:
                logger.info('Validation data not available...')
                pass
            else:
                net.eval()
                with torch.no_grad():
                    loss_val = 0.0
                    n_batches = 0
                    for data in val_loader:
                        inputs, _, idx = data
                        inputs = inputs.to(self.device)
                        outputs = net(inputs)
                        idx = idx.cpu().numpy()
                        if (self.c.dim() == 1):  # naive deep_svdd
                            centers = self.c
                            dist = torch.sum((outputs - self.c)**2, dim=1)
                        else:
                            centers = torch.transpose(self.c, 0, 1)
                            dist = torch.zeros(outputs.shape[0],
                                               device=self.device)
                            triplet = torch.zeros(
                                outputs.shape[0],
                                device=self.device)  ### TRIPLET
                            for i in range(outputs.shape[0]):
                                # Sum dists from each data point to its corresponding cluster
                                # dist[i] = torch.sum((centers - outputs[i]) ** 2, dim=1).min()
                                ### Avoid gradient of min
                                cluster_idx = torch.sum(
                                    (centers - outputs[i])**2, dim=1).argmin()
                                dist[i] = torch.sum(
                                    (centers[cluster_idx] - outputs[i])**2,
                                    dim=0)
                                # dist[i] = torch.sum((centers[cluster_assignments[idx[i]]] - outputs[i]) ** 2, dim=0) ### THESE are not cluster assignments for validation data (only training data)!
                                # Triplet Loss
                                negative_clusters = np.delete(
                                    cluster_numbers, cluster_idx)
                                neg_example = np.random.choice(
                                    negative_clusters)
                                ot = outputs[i].unsqueeze(1)
                                pt = centers[cluster_idx].unsqueeze(1)
                                nt = centers[neg_example].unsqueeze(1)
                                triplet[i] = self.triplet_loss(anchor=ot,
                                                               positive=pt,
                                                               negative=nt)

                        if self.objective == 'soft-boundary':
                            scores = dist - self.R**2
                            loss = self.R**2 + (1 / self.nu) * torch.mean(
                                torch.max(torch.zeros_like(scores), scores))
                        else:
                            if (self.c.dim() == 1):  # naive deep svdd
                                loss = torch.mean(dist)
                            else:
                                loss = torch.mean(triplet)  ### TRIPLET

                        loss_val += loss.item()
                        n_batches += 1
                logger.info('  Validation Loss: {:.8f}'.format(loss_val /
                                                               n_batches))
                if es.step(torch.tensor(loss_val / n_batches)):
                    logger.info('  Stop early at epoch {}/{}'.format(
                        epoch + 1, self.n_epochs))
                    break

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        # UMAP Plot (on training data)
        # Get train data loader
        train_loader, _, _ = dataset.loaders(
            batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        output_data = []
        label_data = []
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, labels, _ = data  #labels are only for UMAP of hyperspheres
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                output_data.append(outputs)
                label_data.append(labels)
        kmeans_centers = np.load(os.path.join(self.xp_path, 'centers.npy'))
        output_data = torch.cat(output_data)
        label_data = torch.cat(label_data).numpy()
        self.latent_UMAP(
            output_data,
            label_data,
            kmeans_centers,
            pretrain_ae=False,
            epoch=0)  ### USE pretrain_ae = False, no repeat, ignore epoch arg

        return net
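
The validation loop above depends on an EarlyStopping helper whose step() returns True once training should stop; it is not included in the excerpt. A minimal sketch matching that interface, with the patience semantics assumed:

class EarlyStopping:
    """Stop when the monitored value has not improved for `patience` consecutive calls."""

    def __init__(self, patience: int = 25, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_steps = 0

    def step(self, metric) -> bool:
        value = float(metric)
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        return self.num_bad_steps >= self.patience
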
Code example #14
    def train(self, dataset: BaseADDataset, net: BaseNet):
        # TODO: re-initialize the weights of lastlay? Is this method called multiple times with different models?
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        optimizerRisk = optim.Adam(self.lastlay.parameters(),
                                   lr=self.lr,
                                   weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        lossfct = UnsupRisk(self.prior0)
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()
                optimizerRisk.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                print("debug batchsize %d" % outputs.size(0))
                scores = self.lastlay(outputs)
                # TODO: handle both choices exact or approx
                loss = lossfct(scores)

                loss.backward()
                optimizer.step()
                optimizerRisk.step()

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #15
File: deepSVDD_trainer.py Project: tianzhaotju/AD
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs, category, resconstruction = net(inputs)
                if self.objective == 'deep-GMM':

                    phi, mu, cov = self.compute_gmm_params(outputs, category)
                    sample_energy, cov_diag = self.compute_energy(outputs, phi=phi, mu=mu, cov=cov,
                                                                size_average=True)

                                                                        # (n,k) --> (n,k,1)
                    # weights = category.unsqueeze(2)
                    #
                    # # (n, k, 1) --> (1, k, 1)
                    # n_k = torch.sum(weights, 0, keepdim=True)
                    #
                    #
                    # # (n,d) ---> (n, k, d)
                    # outputs = outputs.unsqueeze(1).expand(outputs.size(0), self.n_components, outputs.size(1))
                    #
                    # # (n, k, d) --> (1, k, d)
                    # mu = torch.div(torch.sum(weights * outputs, 0, keepdim=True), n_k + self.eps)
                    # var = torch.div(torch.sum(weights * (outputs - mu) * (outputs - mu), 0, keepdim=True), n_k + self.eps)
                    #
                    # self.mu_test = 0.95*self.mu_test + 0.05*mu
                    # self.var_test = 0.95*self.var_test + 0.05*var
                    #
                    # # (1, k, d) --> (n, k, d)
                    # mu = mu.expand(outputs.size(0), self.n_components, self.n_features)
                    # var = var.expand(outputs.size(0), self.n_components, self.n_features)
                    #
                    #
                    #
                    #
                    # # (n, k, d) --> (n, k, 1)
                    # exponent = torch.exp(-.5 * torch.sum((outputs - mu) * (outputs - mu) / var, 2, keepdim=True))
                    # # (n, k, d) --> (n, k, 1)
                    # prefactor = torch.rsqrt(((2. * pi) ** self.n_features) * torch.prod(var, dim=2, keepdim=True) + self.eps)
                    #
                    # # (n, k, 1)
                    # logits_pre = torch.mean(weights, 0, keepdim=True)*prefactor * exponent
                    #
                    # # (n, k, 1) --> (n, k)
                    #
                    # logits_pre = torch.squeeze(logits_pre)


                    #logits = -torch.mean(torch.log(torch.sum(logits_pre, 1) + self.eps))

                    rescon_error = torch.sum((resconstruction - inputs) ** 2, dim=tuple(range(1, resconstruction.dim())))

                    rescon_loss = torch.mean(rescon_error)

                    loss = Variable(sample_energy + rescon_loss, requires_grad=True)


                elif self.objective == 'soft-boundary':
                    dist = torch.sum((outputs - self.c) ** 2, dim=1)
                    scores = dist - self.R ** 2
                    loss = self.R ** 2 + (1 / self.nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
                else:
                    dist = torch.sum((outputs - self.c) ** 2, dim=1)
                    loss = torch.mean(dist)
                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu), device=self.device)

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'
                        .format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #16
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            if 'cw' in self.add_params:
                self.c = self.init_center_c_w(train_loader, net)
            else:
                self.c = self.init_center_c(train_loader, net)
            # self.c_g2 = self.init_center_c_grad(train_loader, net, net.conv2.weight)
            old_mode = self.mode
            if self.mode == 'weight' or self.mode == 'both':
                self.mode = 'weight'
                self.c_g3 = self.init_center_c_grad(train_loader, net, None).detach()
                self.mode = old_mode
            if self.mode == 'input' or self.mode == 'both':
                self.mode = 'input'
                self.c_gi = self.init_center_c_grad(train_loader, net, None).detach()
                self.mode = old_mode
            # self.c_g2 = self.c_g2.detach()
            # self.c_g3 = self.c_g3.detach()
            logger.info('Center c initialized.')

        # Training
        center_update_epochs = 25
        if 'fast_c' in self.add_params:
            center_update_epochs = 5
        logger.info('Starting training...')
        start_time = time.time()
        for epoch in range(self.n_epochs):
            net.train()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_last_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                inputs.requires_grad_(True)
                outputs = net(inputs).squeeze()
                loss3 = None
                old_mode = self.mode
                # grads2 = torch.autograd.grad(outputs=outputs.sum(), inputs=net.conv2.weight, create_graph=True, retain_graph=True)[0]
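                # The 'weight' / 'input' modes add a penalty on how far the gradient of the summed
                # output (taken w.r.t. a weight tensor or the inputs) drifts from its center c_g3 / c_gi.
                # Note: torch.autograd.grad requires a concrete tensor for `inputs`; the None passed
                # below has to be replaced with the intended weight tensor (e.g. a conv layer's weight).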
                if self.mode == 'weight' or self.mode == 'both':
                    self.mode = 'weight'
                    grads3 = torch.autograd.grad(outputs=outputs.sum(), inputs=None, create_graph=True,
                                                 retain_graph=True)[0]
                    dist3 = (grads3 - self.c_g3.expand_as(grads3)) ** 2
                    loss3 = torch.sum(dist3) / outputs.shape[0]
                    self.mode = old_mode
                if self.mode == 'input' or self.mode == 'both':
                    self.mode = 'input'
                    grads3 = torch.autograd.grad(outputs=outputs.sum(), inputs=inputs,
                                                 create_graph=True, retain_graph=True)[0]
                    if 'grad_norm' in self.add_params:
                        grads3 = grads3 / (torch.sqrt(
                            torch.sum(grads3 ** 2, dim=tuple(range(1, len(grads3.shape))), keepdim=True)) + 1e-5)
                    dist3 = (grads3 - self.c_gi.expand_as(grads3)) ** 2
                    dist3 = torch.sum(dist3.view(dist3.shape[0], -1), dim=1)
                    if loss3 is None:
                        loss3 = torch.mean(dist3)
                    else:
                        loss3 = loss3 + torch.mean(dist3)
                    self.mode = old_mode
                inputs.requires_grad_(False)
                # if r is None:
                #    r = torch.randn((1,) + grads.shape[1:], device=self.device)
                # print(outputs.shape, self.c.shape, grads.shape, self.c_g.expand_as(grads).shape)
                dist = torch.sum((outputs - self.c) ** 2, dim=1)

                # dist2 = (grads2 - self.c_g2.expand_as(grads2))**2
                # dist = torch.sum((outputs - self.c) ** 2, dim=1)
                if self.objective == 'soft-boundary':
                    scores = dist - self.R ** 2
                    loss = self.R ** 2 + (1 / self.nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))
                else:
                    loss = torch.mean(dist)

                    # loss2 = torch.mean(dist2)
                    if loss3 is not None:
                        loss = loss + loss3
                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu), device=self.device)

                loss_epoch += loss.item()
                n_batches += 1
            scheduler.step()

            #             if 'update_c' in self.add_params and epoch % center_update_epochs == 0:
            #                 logger.info('Updating center c...')
            #                 self.c = self.init_center_c(train_loader, net)
            #                 #self.c_g2 = self.init_center_c_grad(train_loader, net, net.conv2.weight)
            #                 old_mode = self.mode
            #                 if self.mode == 'weight' or self.mode == 'both':
            #                     self.mode = 'weight'
            #                     self.c_g3 = self.init_center_c_grad(train_loader, net, layer.weight).detach()
            #                     self.mode = old_mode
            #                 if self.mode == 'input' or self.mode == 'both':
            #                     self.mode = 'input'
            #                     self.c_gi = self.init_center_c_grad(train_loader, net, layer.weight).detach()
            #                     self.mode = old_mode
            #                 logger.info('Center c updated.')

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'
                        .format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches))

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #17
0
    def train(self, dataset: BaseADDataset, net: BaseNet):

        logger = logging.getLogger()

        if self.train_loader is None:
            # dataset.loaders() returns either two or three loaders depending on the dataset;
            # catch the unpacking mismatch explicitly instead of using a bare except.
            try:
                self.train_loader, _, _ = dataset.loaders(
                    batch_size=self.batch_size,
                    num_workers=self.n_jobs_dataloader)
            except ValueError:
                self.train_loader, _ = dataset.loaders(
                    batch_size=self.batch_size,
                    num_workers=self.n_jobs_dataloader)

        # Set device for network
        net = net.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Set loss function
        criterion = BCEWithLogitsLoss()
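        # BCEWithLogitsLoss applies the sigmoid internally, so the network is expected to
        # output raw logits (one per sample, matched by targets.unsqueeze(1) below).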

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()

        for epoch in range(self.n_epochs):
            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()

            for data in self.train_loader:  # the loader was stored on self above
                inputs, targets, _, _ = data
                inputs, targets = inputs.to(self.device), targets.to(
                    self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                outputs = net(inputs)
                targets = targets.type_as(outputs)
                loss = criterion(outputs, targets.unsqueeze(1))
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

            if self.reporter:
                self._log_train(net, dataset)
                self.reporter(
                    **{'train/loss/' + str(dataset.id): epoch_loss / n_batches})

        self.train_time = time.time() - start_time
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        return net
Code example #18
0
    def train(self, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set device
        net = net.to(self.device)

        # Use importance weighted sampler (Burda et al., 2015) to get a better estimate on the log-likelihood.
        sampler = ImportanceWeightedSampler(mc=1, iw=1)
        elbo = SVI(net, likelihood=binary_cross_entropy, sampler=sampler)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, labels, semi_targets, _ = data

                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                semi_targets = semi_targets.to(self.device)

                # Get labeled and unlabeled data and make labels one-hot
                inputs = inputs.view(inputs.size(0), -1)
                x = inputs[semi_targets != 0]
                u = inputs[semi_targets == 0]
                y = labels[semi_targets != 0]
                if y.nelement() > 1:
                    y_onehot = torch.Tensor(y.size(0), 2).to(
                        self.device)  # two labels: 0: normal, 1: outlier
                    y_onehot.zero_()
                    y_onehot.scatter_(1, y.view(-1, 1), 1)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                if y.nelement() < 2:
                    L = torch.tensor(0.0).to(self.device)
                else:
                    L = -elbo(x, y_onehot)
                U = -elbo(u)

                # Regular cross entropy
                if y.nelement() < 2:
                    classification_loss = torch.tensor(0.0).to(self.device)
                else:
                    # Add auxiliary classification loss q(y|x)
                    logits = net.classify(x)
                    eps = 1e-8
                    classification_loss = torch.sum(y_onehot *
                                                    torch.log(logits + eps),
                                                    dim=1).mean()

                # Overall loss
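                # J_alpha: labeled ELBO term L, unlabeled ELBO term U, plus an auxiliary
                # classifier term q(y|x) weighted by alpha (cf. the semi-supervised VAE of
                # Kingma et al., 2014).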
                loss = L - self.alpha * classification_loss + U  # J_alpha

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        return net
Code example #19
0
    def train(self, deepSVDD, cfg, dataset: BaseADDataset, net: BaseNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(net.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        best_score = 0

        if self.c is None:
            self.c = self.init_center_c(train_loader, net)

        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            loss_epoch = 0.0
            n_batches = 0
            epoch_start_time = time.time()

            for data in tqdm(train_loader):
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()
                # Update network parameters via backpropagation: forward + backward + optimize
                features, rec_images = net(inputs)
                dist = torch.sum((features - self.c)**2, dim=1)
                rec_loss = torch.mean(
                    torch.sum(torch.abs(rec_images - inputs),
                              dim=tuple(range(1, rec_images.dim()))))
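                # Joint objective: the SVDD distance term and the L1 reconstruction term are
                # combined with the weights cfg.settings['w_svdd'] and cfg.settings['w_rec'].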
                if self.objective == 'soft-boundary':
                    scores = dist - self.R**2
                    loss = cfg.settings['w_svdd'] * (
                        self.R**2 + (1 / self.nu) *
                        torch.mean(torch.max(torch.zeros_like(scores), scores))
                    ) + cfg.settings['w_rec'] * rec_loss
                else:
                    loss = cfg.settings['w_svdd'] * torch.mean(
                        dist) + cfg.settings['w_rec'] * rec_loss
                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                loss_epoch += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}'.format(
                epoch + 1, self.n_epochs, epoch_train_time,
                loss_epoch / n_batches))

            # Test model
            deepSVDD.test(dataset,
                          device=self.device,
                          n_jobs_dataloader=self.n_jobs_dataloader)
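            # Keep results, model, and config only when the test AUC improves on the best
            # score seen so far.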
            if self.test_auc > best_score:
                # Save results, model, and configuration
                best_score = self.test_auc
                deepSVDD.R = float(self.R.cpu().data.numpy())  # get float
                deepSVDD.c = self.c.cpu().data.numpy().tolist()  # get list
                deepSVDD.save_results(export_json=cfg.settings['xp_path'] +
                                      '/results.json')
                deepSVDD.save_model(export_model=cfg.settings['xp_path'] +
                                    '/model.tar')
                cfg.save_config(export_json=cfg.settings['xp_path'] +
                                '/config.json')

                if cfg.settings['dataset_name'] in ('mnist', 'cifar10'):

                    # Plot most anomalous and most normal (within-class) test samples
                    indices, labels, scores = zip(
                        *deepSVDD.results['test_scores'])
                    indices, labels, scores = np.array(indices), np.array(
                        labels), np.array(scores)
                    idx_sorted = indices[labels == 0][np.argsort(
                        scores[labels == 0]
                    )]  # sorted from lowest to highest anomaly score

                    if cfg.settings['dataset_name'] == 'mnist':
                        X_normals = dataset.test_set.data[idx_sorted[:32],
                                                          ...].unsqueeze(1)
                        X_outliers = dataset.test_set.data[idx_sorted[-32:],
                                                           ...].unsqueeze(1)

                    if cfg.settings['dataset_name'] == 'cifar10':
                        X_normals = torch.tensor(
                            np.transpose(
                                dataset.test_set.data[idx_sorted[:32], ...],
                                (0, 3, 1, 2)))
                        X_outliers = torch.tensor(
                            np.transpose(
                                dataset.test_set.data[idx_sorted[-32:], ...],
                                (0, 3, 1, 2)))

                    plot_images_grid(X_normals,
                                     export_img=cfg.settings['xp_path'] +
                                     '/normals',
                                     title='Most normal examples',
                                     padding=2)
                    plot_images_grid(X_outliers,
                                     export_img=cfg.settings['xp_path'] +
                                     '/outliers',
                                     title='Most anomalous examples',
                                     padding=2)

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #20
0
    def train(self, dataset: BaseADDataset, net: BaseNet):

        # (Commented-out experiment) Freeze every parameter except the 'dense1' layer:
        # for k, v in net.named_parameters():
        #     if k != 'dense1.weight' and k != 'dense1.bias':
        #         v.requires_grad = False

        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set optimizer (Adam optimizer for now)
        #optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay,amsgrad=self.optimizer_name == 'amsgrad')
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      net.parameters()),
                               lr=self.lr,
                               weight_decay=self.weight_decay,
                               amsgrad=self.optimizer_name == 'amsgrad')

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Initialize hypersphere center c (if c not loaded)
        if self.c is None:
            logger.info('Initializing center c...')
            self.c = self.init_center_c(train_loader, net)
            logger.info('Center c initialized.')

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        ssim_loss = pytorch_ssim.SSIM(window_size=11)
        for epoch in range(self.n_epochs):

            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))
            loss_rec = 0.0
            loss_epoch = 0.0
            loss_diag = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                # Both branches unpacked the batch identically, so the dataset-name check is redundant
                inputs, _, _ = data
                inputs = inputs.to(self.device)

                # Zero the network parameter gradients
                optimizer.zero_grad()
                sample_energy = 0
                cov_diag = 0
                # Update network parameters via backpropagation: forward + backward + optimize
                outputs, category, resconstruction = net(inputs)
                if self.objective == 'deep-GMM':

                    phi, mu, cov = self.compute_gmm_params(outputs, category)
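                    # Deep-GMM objective (in the spirit of DAGMM): estimate mixture weights, means,
                    # and covariances from the encoder outputs and the predicted memberships, then
                    # use the sample energy (negative log-likelihood under the mixture) as the
                    # anomaly term, plus a reconstruction term and a covariance-diagonal penalty.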

                    sample_energy, cov_diag = self.compute_energy(
                        outputs, phi=phi, mu=mu, cov=cov, size_average=True)
                    if self.ae_loss_type == 'ssim':
                        rescon_error = -ssim_loss(inputs, resconstruction)
                        rescon_error = self.ssim_lambda * rescon_error
                    else:
                        rescon_error = torch.sum(
                            (resconstruction - inputs)**2,
                            dim=tuple(range(1, resconstruction.dim())))
                        rescon_error = self.l2_lambda * rescon_error
                    rescon_loss = torch.mean(rescon_error)

                    loss = sample_energy + rescon_loss + self.cov_diag_lambda * cov_diag

                elif self.objective == 'soft-boundary':
                    dist = torch.sum((outputs - self.c)**2, dim=1)
                    scores = dist - self.R**2
                    loss = self.R**2 + (1 / self.nu) * torch.mean(
                        torch.max(torch.zeros_like(scores), scores))
                    # Defined here as well so the epoch bookkeeping below does not raise a
                    # NameError when this objective is selected.
                    dist_ave = torch.mean(dist)
                    rescon_loss = torch.tensor(0.0, device=self.device)
                elif self.objective == 'hybrid':
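                    # 'hybrid': the per-sample score adds the reconstruction error to the squared
                    # distance to the center c, and the loss is the mean of that score.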
                    dist = torch.sum((outputs - self.c)**2, dim=1)
                    if self.ae_loss_type == 'ssim':
                        rescon_error = -ssim_loss(inputs, resconstruction)
                        rescon_error = self.ssim_lambda * rescon_error
                    else:
                        rescon_error = torch.sum(
                            (resconstruction - inputs)**2,
                            dim=tuple(range(1, resconstruction.dim())))
                        rescon_error = self.l2_lambda * rescon_error
                    rescon_loss = torch.mean(rescon_error)
                    dist_ave = torch.mean(dist)
                    dist = rescon_error + dist
                    loss = torch.mean(dist)

                else:
                    if self.ae_loss_type == 'ssim':
                        rescon_error = -ssim_loss(inputs, resconstruction)
                        rescon_error = self.ssim_lambda * rescon_error
                    else:
                        rescon_error = torch.sum(
                            (resconstruction - inputs)**2,
                            dim=tuple(range(1, resconstruction.dim())))
                        rescon_error = self.l2_lambda * rescon_error
                    rescon_loss = torch.mean(rescon_error)
                    dist = torch.sum((outputs - self.c)**2, dim=1)
                    dist_ave = torch.mean(dist)
                    loss = torch.mean(dist)

                loss.backward()
                optimizer.step()

                # Update hypersphere radius R on mini-batch distances
                if (self.objective == 'soft-boundary') and (
                        epoch >= self.warm_up_n_epochs):
                    self.R.data = torch.tensor(get_radius(dist, self.nu),
                                               device=self.device)

                if self.objective == 'deep-GMM':
                    loss_epoch += sample_energy.item()
                else:
                    loss_epoch += dist_ave.item()
                #loss_diag += self.cov_diag_lambda*cov_diag.item()
                loss_rec += rescon_loss.item()

                n_batches += 1
            scheduler.step()

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            #logger.info('  Epoch {}/{}\t Time: {:.3f}\t Energy: {:.8f}, Cov_diag:  {:.8f},  Reconstruction:  {:.8f}'.format(epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches, loss_diag, loss_rec/ n_batches))
            logger.info(
                '  Epoch {}/{}\t Time: {:.3f}\t Energy: {:.8f}, Reconstruction:  {:.8f}'
                .format(epoch + 1, self.n_epochs, epoch_train_time,
                        loss_epoch / n_batches, loss_rec / n_batches))
            self.l2_lambda_test = loss_epoch / loss_rec
        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f' % self.train_time)

        logger.info('Finished training.')

        return net
Code example #21
0
def train(dataset: BaseADDataset, ae_net: BaseNet):
    # Set device for network
    ae_net = ae_net.to(device)

    # Get train data loader
    letter, labels = dataset.loaders(batch_size=batch_size,
                                     num_workers=0,
                                     shuffle_test=False,
                                     shuffle_train=False)

    # Set optimizer (Adam optimizer for now)
    optimizer = optim.Adam(ae_net.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)

    # Training
    start_time = time.time()
    ae_net.train()
    for epoch in range(n_epochs):
        loss_epoch = 0.0
        n_batches = 0
        epoch_start_time = time.time()
        for data, label in zip(letter, labels):
            inputs, _ = data
            lab, _ = label
            inputs = inputs.to(device)
            lab = lab.to(device)
            # Zero the network parameter gradients
            optimizer.zero_grad()
            outputs = ae_net(inputs)
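            # The network is trained to map each input image to its paired target image;
            # the per-sample score is the squared error summed over all non-batch dimensions.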
            scores = torch.sum((outputs - lab)**2,
                               dim=tuple(range(1, outputs.dim())))
            loss = torch.mean(scores)
            loss.backward()
            optimizer.step()

            loss_epoch += loss.item()
            n_batches += 1
        epoch_end_time = time.time()
        print('Epoch: ' + str(epoch + 1) + '/' + str(n_epochs) + ' time: ' +
              str(epoch_end_time - epoch_start_time) + ' loss: ' +
              str(loss_epoch / n_batches))

    with torch.no_grad():
        plot_images_grid(inputs,
                         export_img='./log/train/input',
                         title='Input ',
                         nrow=4,
                         padding=4)
        plot_images_grid(lab,
                         export_img='./log/train/label',
                         title='Label ',
                         nrow=4,
                         padding=4)
        plot_images_grid(outputs,
                         export_img='./log/train/output',
                         title='Output ',
                         nrow=4,
                         padding=4)

    return ae_net
Code example #22
0
File: vae_trainer.py  Project: MasaKat0/D3RE
    def train(self, dataset: BaseADDataset, vae: BaseNet):
        logger = logging.getLogger()

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Set device
        vae = vae.to(self.device)

        # Set optimizer (Adam optimizer for now)
        optimizer = optim.Adam(vae.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting pretraining...')
        start_time = time.time()
        vae.train()
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            epoch_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            for data in train_loader:
                inputs, _, _, _ = data
                inputs = inputs.to(self.device)
                inputs = inputs.view(inputs.size(0), -1)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize
                rec = vae(inputs)

                likelihood = -binary_cross_entropy(rec, inputs)
                elbo = likelihood - vae.kl_divergence
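                # ELBO = E[log p(x|z)] - KL(q(z|x) || p(z)); minimizing the negative mean ELBO
                # maximizes a lower bound on the data log-likelihood.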

                # Overall loss
                loss = -torch.mean(elbo)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

        self.train_time = time.time() - start_time
        logger.info('Pretraining Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished pretraining.')

        return vae