Ejemplo n.º 1
0
class Ganomaly(object):
    """GANomaly Class

    Bundles the generator (``NetG``) and discriminator (``NetD``), their
    losses and optimizers, the training loop and the evaluation routine.
    During evaluation a sample's anomaly score is the mean squared
    difference between the two latent codes returned by ``NetG``, min-max
    scaled over the whole test set.
    """
    @staticmethod
    def name():
        """Return name of the class.
        """
        return 'Ganomaly'

    def __init__(self, opt, dataloader=None):
        """Create networks, losses, input buffers and (if training) optimizers.

        Args:
            opt: Options namespace (output paths, device, image and batch
                size, optimizer hyper-parameters, ``resume`` directory, ...).
            dataloader: Mapping with ``'train'`` and ``'test'`` data loaders.
        """
        super(Ganomaly, self).__init__()
        ##
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")

        # -- Discriminator attributes.
        self.out_d_real = None
        self.feat_real = None
        self.err_d_real = None
        self.fake = None
        self.latent_i = None
        self.latent_o = None
        self.out_d_fake = None
        self.feat_fake = None
        self.err_d_fake = None
        self.err_d = None

        # -- Generator attributes.
        self.out_g = None
        self.err_g_bce = None
        self.err_g_l1l = None
        self.err_g_enc = None
        self.err_g = None

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        self.netg = NetG(self.opt).to(self.device)
        self.netd = NetD(self.opt).to(self.device)
        self.netg.apply(weights_init)
        self.netd.apply(weights_init)

        ##
        # Optionally resume from checkpoints written by ``save_weights``.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            # Read each checkpoint file once; map_location lets weights saved
            # on a GPU be restored on a CPU-only host.
            netg_ckpt = torch.load(os.path.join(self.opt.resume, 'netG.pth'),
                                   map_location=self.device)
            self.opt.iter = netg_ckpt['epoch']
            self.netg.load_state_dict(netg_ckpt['state_dict'])
            netd_ckpt = torch.load(os.path.join(self.opt.resume, 'netD.pth'),
                                   map_location=self.device)
            self.netd.load_state_dict(netd_ckpt['state_dict'])
            print("\tDone.\n")

        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = l2_loss

        ##
        # Initialize input tensors (resized to the real batch in set_input).
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(self.opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = 1
        self.fake_label = 0

        ##
        # Setup optimizer
        if self.opt.isTrain:
            self.netg.train()
            self.netd.train()
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))

    ##
    def set_input(self, input):
        """ Set input and ground truth

        Args:
            input: ``(images, labels)`` pair for batch i; copied into the
                pre-allocated ``self.input`` / ``self.gt`` buffers.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            with torch.no_grad():
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2
        """
        ##
        # Feature Matching.
        self.netd.zero_grad()
        # --
        # Train with real
        # NOTE: self.label is filled here but the feature-matching loss
        # below does not read it.
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_d_real, self.feat_real = self.netd(self.input)
        # --
        # Train with fake
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.fake, self.latent_i, self.latent_o = self.netg(self.input)
        # detach(): the D step must not push gradients into G.
        self.out_d_fake, self.feat_fake = self.netd(self.fake.detach())
        # --
        self.err_d = l2_loss(self.feat_real, self.feat_fake)
        self.err_d_real = self.err_d
        self.err_d_fake = self.err_d
        self.err_d.backward()
        self.optimizer_d.step()

    ##
    def reinitialize_netd(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def update_netg(self):
        """
        # ============================================================ #
        # (2) Update G network: log(D(G(x)))  + ||G(x) - x||           #
        # ============================================================ #

        Weighted sum of the adversarial BCE term (``w_bce``), the L1
        reconstruction term (``w_rec``) and the latent L2 term (``w_enc``).
        Reuses ``self.fake`` / latents computed in ``update_netd``.
        """
        self.netg.zero_grad()
        with torch.no_grad():
            self.label.resize_(self.opt.batchsize).fill_(self.real_label)
        self.out_g, _ = self.netd(self.fake)

        self.err_g_bce = self.bce_criterion(self.out_g, self.label)
        self.err_g_l1l = self.l1l_criterion(
            self.fake, self.input)  # constrain x' to look like x
        self.err_g_enc = self.l2l_criterion(self.latent_o, self.latent_i)
        self.err_g = (self.err_g_bce * self.opt.w_bce +
                      self.err_g_l1l * self.opt.w_rec +
                      self.err_g_enc * self.opt.w_enc)

        self.err_g.backward(retain_graph=True)
        self.optimizer_g.step()

    ##
    def optimize(self):
        """ Optimize netD and netG  networks.
        """

        self.update_netd()
        self.update_netg()

        # If D loss is (near) zero, then re-initialize netD
        if self.err_d_real.item() < 1e-5 or self.err_d_fake.item() < 1e-5:
            self.reinitialize_netd()

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """

        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_d_real', self.err_d_real.item()),
                              ('err_d_fake', self.err_d_fake.item()),
                              ('err_g_bce', self.err_g_bce.item()),
                              ('err_g_l1l', self.err_g_l1l.item()),
                              ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Each file stores ``{'epoch': epoch + 1, 'state_dict': ...}`` under
        ``<outf>/<name>/train/weights/``.

        Args:
            epoch ([int]): Current epoch number.
        """

        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        # exist_ok avoids the race of a separate exists()/makedirs() pair.
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, os.path.join(weight_dir, 'netG.pth'))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, os.path.join(weight_dir, 'netD.pth'))

    ##
    def train_epoch(self):
        """ Train the model for one epoch.
        """

        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name(), self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Trains from ``opt.iter`` to ``opt.niter``, evaluating after every
        epoch and saving weights whenever the AUC improves.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    ##
    @staticmethod
    def _min_max_scale(scores):
        """Linearly rescale ``scores`` to [0, 1].

        When every score is identical the range is zero; return zeros
        instead of dividing by zero (which would yield NaNs).
        """
        low = torch.min(scores)
        span = torch.max(scores) - low
        if span == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Runs NetG over ``self.dataloader['test']`` and scores every sample
        by the mean squared difference of its two latent codes.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` (mean over at most the
            first 100 batches) and ``'AUC'`` as returned by ``evaluate``.

        Raises:
            IOError: ``opt.load_weights`` is set and the netG checkpoint
                cannot be read.
        """
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name().lower(), self.opt.dataset)
                # torch.load is the call that touches the filesystem, so it is
                # the one to guard (load_state_dict raises RuntimeError).
                try:
                    checkpoint = torch.load(path, map_location=self.device)
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(checkpoint['state_dict'])
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_test, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # The last batch may be smaller than batchsize.
                start = i * self.opt.batchsize
                end = start + error.size(0)
                self.an_scores[start:end] = error.reshape(error.size(0))
                self.gt_labels[start:end] = self.gt.reshape(error.size(0))
                self.latent_i[start:end, :] = latent_i.reshape(
                    error.size(0), self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(
                    error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
Ejemplo n.º 2
0
class BaseModel():
    """ Base Model for ganomaly

    Holds the options, data loaders, visualizer and output paths shared by
    concrete models, and implements the generic train / evaluate loop.
    Subclasses must provide ``name``, ``netg``, ``netd``, the ``input`` /
    ``gt`` / ``label`` / ``fixed_input`` buffers, the ``err_*`` losses and
    ``optimize_params()``; none of these are created here.
    """
    def __init__(self, opt, dataloader):
        """Seed RNGs and store options, data loaders, paths and device.

        Args:
            opt: Options namespace.
            dataloader: Mapping with ``'train'`` and ``'test'`` data loaders.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def set_input(self, input:torch.Tensor):
        """ Set input and ground truth

        Args:
            input: ``(images, labels)`` pair for batch i.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            # Only resized; the contents are left for the subclass to fill.
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def seed(self, seed_value):
        """ Seed Python, NumPy and PyTorch RNGs.

        Arguments:
            seed_value {int} -- Seed to apply; -1 leaves all RNGs unseeded.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """

        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Each file stores ``{'epoch': epoch + 1, 'state_dict': ...}`` under
        ``<outf>/<name>/train/weights/``.

        Args:
            epoch ([int]): Current epoch number.
        """

        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        # exist_ok avoids the race of a separate exists()/makedirs() pair.
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
                   os.path.join(weight_dir, 'netG.pth'))
        torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
                   os.path.join(weight_dir, 'netD.pth'))

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.
        """

        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'], leave=False, total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Trains from ``opt.iter`` to ``opt.niter``, evaluating after every
        epoch and saving weights whenever the AUC improves.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    @staticmethod
    def _min_max_scale(scores):
        """Linearly rescale ``scores`` to [0, 1].

        When every score is identical the range is zero; return zeros
        instead of dividing by zero (which would yield NaNs).
        """
        low = torch.min(scores)
        span = torch.max(scores) - low
        if span == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    ##
    def test(self):
        """ Test GANomaly model.

        Runs netG over ``self.dataloader['test']`` and scores every sample by
        the mean squared difference of its two latent codes.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` (mean over at most the
            first 100 batches) and ``'AUC'`` as returned by ``evaluate``.

        Raises:
            IOError: ``opt.load_weights`` is set and the netG checkpoint
                cannot be read.
        """
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                # torch.load is the call that touches the filesystem, so it is
                # the one to guard (load_state_dict raises RuntimeError).
                try:
                    checkpoint = torch.load(path, map_location=self.device)
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(checkpoint['state_dict'])
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_test,), dtype=torch.long,    device=self.device)
            self.latent_i  = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o  = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
                time_o = time.time()

                # The last batch may be smaller than batchsize.
                start = i * self.opt.batchsize
                end = start + error.size(0)
                self.an_scores[start:end] = error.reshape(error.size(0))
                self.gt_labels[start:end] = self.gt.reshape(error.size(0))
                self.latent_i[start:end, :] = latent_i.reshape(error.size(0), self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)

            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = self._min_max_scale(self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)
            return performance
Ejemplo n.º 3
0
class BaseModel():
    """ Base Model for ganomaly
    """
    def __init__(self, opt, dataloader):
        """Seed the RNGs, then store options, loaders, output paths and device.

        Args:
            opt: Options namespace; ``manualseed``, ``outf``, ``name`` and
                ``device`` are read here.
            dataloader: Mapping from split name (``'train'`` / ``'test'``) to
                a data loader.
        """
        # Seeding comes first, before any other setup work.
        self.seed(opt.manualseed)

        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader

        run_root = os.path.join(self.opt.outf, self.opt.name)
        self.trn_dir = os.path.join(run_root, 'train')
        self.tst_dir = os.path.join(run_root, 'test')

        wants_cpu = self.opt.device == 'cpu'
        self.device = torch.device("cpu" if wants_cpu else "cuda:0")

    ##
    def set_input(self, input: torch.Tensor):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            #self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            if self.opt.dataset == 'mnist':
                abn_idx = torch.from_numpy((np.where(
                    input[1].numpy() == int(self.opt.abnormal_class)))[0])

            elif self.opt.dataset == 'fashionmnist':
                classes = {
                    'tshirt': 0,
                    'trouser': 1,
                    'pullover': 2,
                    'dress': 3,
                    'coat': 4,
                    'sandal': 5,
                    'shirt': 6,
                    'sneacker': 7,
                    'bag': 8,
                    'boot': 9
                }
                abn_idx = torch.from_numpy((np.where(
                    input[1].numpy() == int(classes[self.opt.abnormal_class]))
                                            )[0])

            elif self.opt.dataset == 'svhn':
                abn_idx = torch.from_numpy((np.where(
                    input[1].numpy() == int(self.opt.abnormal_class)))[0])

            else:
                classes = {
                    'plane': 0,
                    'car': 1,
                    'bird': 2,
                    'cat': 3,
                    'deer': 4,
                    'dog': 5,
                    'frog': 6,
                    'horse': 7,
                    'ship': 8,
                    'truck': 9
                }
                abn_idx = torch.from_numpy((np.where(
                    input[1].numpy() == int(classes[self.opt.abnormal_class]))
                                            )[0])
            test = torch.zeros(input[1].size())
            test[abn_idx] = 1
            self.gt = test

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def seed(self, seed_value):
        """ Seed 
        
        Arguments:
            seed_value {int} -- [description]
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """

        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_g_adv', self.err_g_adv.item()),
                              ('err_g_con', self.err_g_con.item()),
                              ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.
        """

        #weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        #if not os.path.exists(weight_dir): os.makedirs(weight_dir)
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'test',
                                  'abnormal' + str(self.opt.abnormal_class),
                                  'seed' + str(self.opt.manualseed))

        if not os.path.exists(weight_dir): os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self):
        """Run a single pass over the training loader.

        Per batch: advance the step counters, load the batch, call
        ``optimize_params()``, then periodically plot losses
        (``opt.print_freq``) and save/display images (``opt.save_image_freq``).
        """
        self.netg.train()
        train_loader = self.dataloader['train']
        epoch_iter = 0
        for batch in tqdm(train_loader, leave=False, total=len(train_loader)):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize
            self.set_input(batch)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                losses = self.get_errors()
                if self.opt.display:
                    progress = float(epoch_iter) / len(train_loader.dataset)
                    self.visualizer.plot_current_errors(self.epoch, progress,
                                                        losses)

            if self.total_steps % self.opt.save_image_freq == 0:
                images = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, *images)
                if self.opt.display:
                    self.visualizer.display_current_images(*images)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
            #     self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)
        self.save_weights(self.epoch)

    ##
    def test(self):
        """ Test GANomaly model on ``self.dataloader['test']``.

        Every test sample is scored by the mean squared difference between the
        two latent codes (``latent_i``, ``latent_o``) returned by
        ``self.netg``. The scores are min-max scaled to [0, 1] (all zeros if
        every score is identical) and evaluated with
        ``evaluate(..., metric=self.opt.metric)``.

        Side effects: sets ``self.opt.phase = 'test'`` and overwrites
        ``self.an_scores``, ``self.gt_labels``, ``self.latent_i``,
        ``self.latent_o``, ``self.times`` and ``self.total_steps``.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` (mean over at most the
            first 100 batches) and ``'AUC'``.

        Raises:
            FileNotFoundError: netG weights not found (only when
                ``self.opt.load_weights`` is set).
        """
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                # ``name`` is a staticmethod on this class and must be called
                # to obtain the string.
                model_name = self.name() if callable(self.name) else self.name
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    model_name.lower(), self.opt.dataset)
                # torch.load is the call that fails on a missing file, so it
                # is the one that belongs inside the try.
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                except FileNotFoundError as err:
                    raise FileNotFoundError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'

            loader = self.dataloader['test']
            n_samples = len(loader.dataset)
            bs = self.opt.batchsize
            nz = self.opt.nz

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(n_samples, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_samples, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_samples, nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_samples, nz),
                                        dtype=torch.float32,
                                        device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(loader, 0):
                self.total_steps += bs
                epoch_iter += bs
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Write this batch into its slice of the preallocated buffers
                # (the last batch may be shorter than ``bs``).
                n = error.size(0)
                start, stop = i * bs, i * bs + n
                self.an_scores[start:stop] = error.reshape(n)
                self.gt_labels[start:stop] = self.gt.reshape(n)
                self.latent_i[start:stop, :] = latent_i.reshape(n, nz)
                self.latent_o[start:stop, :] = latent_o.reshape(n, nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (ms per batch, first 100 batches).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]. Plain min-max scaling divides
            # by zero when every score is identical (NaN everywhere), and
            # torch.min fails on an empty tensor; handle both explicitly.
            if self.an_scores.numel() > 0:
                score_min = torch.min(self.an_scores)
                score_range = torch.max(self.an_scores) - score_min
                if score_range == 0:
                    self.an_scores = torch.zeros_like(self.an_scores)
                else:
                    self.an_scores = (self.an_scores - score_min) / score_range

            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_samples
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance

    def train_final(self):
        """ Test GANomaly model.

        Args:
            dataloader ([type]): Dataloader for the test set

        Raises:
            IOError: Model weights not found.
        """
        with torch.no_grad():
            #self.netg.eval()
            #self.dataloader=dataL
            #self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['train'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['train'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(len(
                self.dataloader['train'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(len(
                self.dataloader['train'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            # print("   Testing model %s." % self.name)
            self.times = []
            self.total_steps = 0
            epoch_iter = 0

            for i, data in enumerate(self.dataloader['train'], 0):

                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                self.latent_i[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_i.reshape(
                                  error.size(0), self.opt.nz)
                self.latent_o[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_o.reshape(
                                  error.size(0), self.opt.nz)
                self.times.append(time_o - time_i)

            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))

            recon_1 = self.latent_i - self.latent_o

            return recon_1

    def test_final(self):
        """ Score the test set and return latent residuals, labels and metrics.

        Runs ``self.netg`` over ``self.dataloader['test']`` without gradient
        tracking. For each sample it records the two latent codes
        (``latent_i``, ``latent_o``), the error
        ``mean((latent_i - latent_o) ** 2)``, the label ``self.gt`` and the
        label ``self.label`` (both as left by ``self.set_input``). For the
        first batch, the generator output ``self.fake`` and the input are
        written as image grids under
        ``<outf>/<dataset>/test/OCSVM/abnormal<abnormal_class>/``.

        Side effects: sets ``self.opt.phase = 'test'`` and overwrites
        ``self.an_scores`` (min-max scaled to [0, 1]; all zeros if every score
        is identical), ``self.gt_labels``, ``self.gt_labels_original``,
        ``self.latent_i``, ``self.latent_o``, ``self.times`` and
        ``self.total_steps``.

        Returns:
            tuple: ``(recon_1, y_true, y_true_original, auroc_value,
            auprc_value)`` where ``recon_1 = self.latent_i - self.latent_o``,
            ``y_true`` is ``self.gt_labels``, ``y_true_original`` is
            ``self.gt_labels_original`` and the last two come from
            ``evaluate_final``.

        Raises:
            FileNotFoundError: netG weights not found (only when
                ``self.opt.load_weights`` is set).
        """
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                # ``name`` is a staticmethod on this class and must be called
                # to obtain the string.
                model_name = self.name() if callable(self.name) else self.name
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    model_name.lower(), self.opt.dataset)
                # torch.load is the call that fails on a missing file, so it
                # is the one that belongs inside the try.
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                except FileNotFoundError as err:
                    raise FileNotFoundError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'

            loader = self.dataloader['test']
            n_samples = len(loader.dataset)
            bs = self.opt.batchsize
            nz = self.opt.nz

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(n_samples, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_samples, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.gt_labels_original = torch.zeros(size=(n_samples, ),
                                                  dtype=torch.long,
                                                  device=self.device)
            self.latent_i = torch.zeros(size=(n_samples, nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_samples, nz),
                                        dtype=torch.float32,
                                        device=self.device)

            self.times = []
            self.total_steps = 0

            for i, data in enumerate(loader, 0):
                self.total_steps += bs
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Dump the first batch as image grids. This happens after
                # time_o so disk I/O is not counted as inference time.
                if i == 0:
                    viz_dir = os.path.join(
                        self.opt.outf, self.opt.dataset, 'test', 'OCSVM',
                        'abnormal' + str(self.opt.abnormal_class))
                    print(viz_dir)
                    if not os.path.isdir(viz_dir):
                        os.makedirs(viz_dir)

                    viz.viz_batch_im(np.squeeze(np.array(self.fake.cpu())),
                                     grid_size=[7, 7],
                                     save_path=os.path.join(
                                         viz_dir, 'reconstructed' +
                                         str(self.opt.abnormal_class) +
                                         '.png'),
                                     gap=0,
                                     gap_color=0,
                                     shuffle=False)
                    viz.viz_batch_im(np.squeeze(np.array(self.input.cpu())),
                                     grid_size=[7, 7],
                                     save_path=os.path.join(
                                         viz_dir, 'img' +
                                         str(self.opt.abnormal_class) +
                                         '.png'),
                                     gap=0,
                                     gap_color=0,
                                     shuffle=False)

                # Write this batch into its slice of the preallocated buffers
                # (the last batch may be shorter than ``bs``).
                n = error.size(0)
                start, stop = i * bs, i * bs + n
                self.an_scores[start:stop] = error.reshape(n)
                self.gt_labels[start:stop] = self.gt.reshape(n)
                self.gt_labels_original[start:stop] = self.label.reshape(n)
                self.latent_i[start:stop, :] = latent_i.reshape(n, nz)
                self.latent_o[start:stop, :] = latent_o.reshape(n, nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (ms per batch, first 100 batches).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]. Plain min-max scaling divides
            # by zero when every score is identical (NaN everywhere), and
            # torch.min fails on an empty tensor; handle both explicitly.
            if self.an_scores.numel() > 0:
                score_min = torch.min(self.an_scores)
                score_range = torch.max(self.an_scores) - score_min
                if score_range == 0:
                    self.an_scores = torch.zeros_like(self.an_scores)
                else:
                    self.an_scores = (self.an_scores - score_min) / score_range

            auroc_value, auprc_value, _ = evaluate_final(
                self.gt_labels.cpu().detach(),
                self.an_scores.cpu().detach())

            recon_1 = self.latent_i - self.latent_o

            y_true = self.gt_labels
            y_true_original = self.gt_labels_original

            return recon_1, y_true, y_true_original, auroc_value, auprc_value
Ejemplo n.º 4
0
class BaseModel():
    """ Base Model for ganomaly

    Shared state plus the generic train / validate / checkpoint logic.

    NOTE(review): this class never creates ``netg``/``netd``, the
    ``input``/``gt``/``label``/``noise``/``fixed_input`` buffers, ``name``,
    ``optimize_params``, ``schedulers``/``optimizers`` or the ``err_*``
    losses; presumably subclasses provide them — verify against the
    concrete model.
    """
    def __init__(self, opt, data):
        """ Seed RNGs and store options, data and output locations.

        Args:
            opt: Option namespace (reads ``manualseed``, ``outf``, ``name``
                and ``device`` here).
            data: Data container; ``data.train`` and ``data.valid`` are
                iterated as loaders by the training / test loops.
        """
        ##
        # Seed for deterministic behavior
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.data = data
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device("cuda:0" if self.opt.device != 'cpu' else "cpu")

    ##
    def seed(self, seed_value):
        """ Seed Python, NumPy and PyTorch RNGs for reproducible runs.

        Arguments:
            seed_value {int} -- Seed to use; ``-1`` leaves every RNG untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        # Also request deterministic cuDNN kernels.
        torch.backends.cudnn.deterministic = True

    ##
    def set_input(self, input:torch.Tensor, noise:bool=False):
        """ Set input and ground truth

        Args:
            input: Batch indexable as ``input[0]`` (data) and ``input[1]``
                (targets), copied into ``self.input`` / ``self.gt``.
            noise (bool): If True, refill ``self.noise`` with N(0, 1) samples.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            # Only resized, not filled: the label values are left for the
            # caller to write.
            self.label.resize_(input[1].size())

            # Add noise to the input.
            if noise: self.noise.data.copy_(torch.randn(self.noise.size()))

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Python floats for err_d, err_g, err_g_adv,
            err_g_con and err_g_lat.
        """
        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_lat', self.err_g_lat.item())])

        return errors

    ##
    def reinit_d(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]: current input, current generator output and
            the generator's first output for ``self.fixed_input``.
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch:int, is_best:bool=False):
        """Save netG and netD weights for the current epoch.

        Files go to ``<outf>/<name>/train/weights`` as ``net{G,D}_<epoch>.pth``
        or, with ``is_best``, ``net{G,D}_best.pth``.

        Args:
            epoch ([int]): Current epoch number.
            is_best (bool): Save under the ``_best`` file names instead.
        """

        weight_dir = os.path.join(
            self.opt.outf, self.opt.name, 'train', 'weights')
        os.makedirs(weight_dir, exist_ok=True)

        if is_best:
            torch.save({'epoch': epoch, 'state_dict': self.netg.state_dict()}, f'{weight_dir}/netG_best.pth')
            torch.save({'epoch': epoch, 'state_dict': self.netd.state_dict()}, f'{weight_dir}/netD_best.pth')
        else:
            torch.save({'epoch': epoch, 'state_dict': self.netd.state_dict()}, f"{weight_dir}/netD_{epoch}.pth")
            torch.save({'epoch': epoch, 'state_dict': self.netg.state_dict()}, f"{weight_dir}/netG_{epoch}.pth")

    def load_weights(self, epoch=None, is_best:bool=False, path=None):
        """ Load pre-trained weights of NetG and NetD

        Keyword Arguments:
            epoch {int}     -- Epoch to be loaded  (default: {None})
            is_best {bool}  -- Load the best epoch (default: {False})
            path {str}      -- Directory containing the ``netG_*.pth`` /
                               ``netD_*.pth`` files (default: {None}, i.e.
                               ``./output/<name>/<dataset>/train/weights``)

        Raises:
            Exception -- Neither ``epoch`` nor ``is_best`` was given.
            FileNotFoundError -- A weight file does not exist.
        """

        if epoch is None and is_best is False:
            raise Exception('Please provide epoch to be loaded or choose the best epoch.')

        if is_best:
            fname_g = "netG_best.pth"
            fname_d = "netD_best.pth"
        else:
            fname_g = f"netG_{epoch}.pth"
            fname_d = f"netD_{epoch}.pth"

        # Previously the file paths were only defined when ``path`` was None,
        # so passing ``path`` crashed with UnboundLocalError.
        if path is None:
            path = f"./output/{self.name}/{self.opt.dataset}/train/weights"
        path_g = os.path.join(path, fname_g)
        path_d = os.path.join(path, fname_d)

        # Load the weights of netg and netd.
        print('>> Loading weights...')
        # torch.load is the call that fails on a missing file, so it is the
        # one that belongs inside the try.
        try:
            weights_g = torch.load(path_g)['state_dict']
            weights_d = torch.load(path_d)['state_dict']
        except FileNotFoundError as err:
            raise FileNotFoundError(f"netG/netD weights not found ({err})") from err
        self.netg.load_state_dict(weights_g)
        self.netd.load_state_dict(weights_d)
        print('   Done.')

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.
        """

        self.netg.train()
        epoch_iter = 0
        for data in tqdm(self.data.train, leave=False, total=len(self.data.train)):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(self.data.train.dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs epochs ``[opt.iter, opt.niter)``, validates after each one and
        saves the weights whenever the validation AUC improves.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(f">> Training {self.name} on {self.opt.dataset} to detect {self.opt.abnormal_class}")
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    def test(self):
        """ Test GANomaly model on ``self.data.valid``.

        Each sample is scored by the mean squared difference between the two
        latent codes returned by ``self.netg``; the scores are min-max scaled
        to [0, 1] (all zeros if every score is identical) and passed to
        ``evaluate(..., metric=self.opt.metric)``.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` (mean over at most the
            first 100 batches) and ``'AUC'``.

        Raises:
            FileNotFoundError: netG weights not found (only when
                ``self.opt.load_weights`` is set).
        """
        with torch.no_grad():
            # Load the weights of netg.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                except FileNotFoundError as err:
                    raise FileNotFoundError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'

            n_valid = len(self.data.valid.dataset)
            bs = self.opt.batchsize
            nz = self.opt.nz

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(n_valid,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_valid,), dtype=torch.long,    device=self.device)
            self.latent_i  = torch.zeros(size=(n_valid, nz), dtype=torch.float32, device=self.device)
            self.latent_o  = torch.zeros(size=(n_valid, nz), dtype=torch.float32, device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.data.valid, 0):
                self.total_steps += bs
                epoch_iter += bs
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
                time_o = time.time()

                # Write this batch into its slice of the preallocated buffers.
                n = error.size(0)
                start, stop = i * bs, i * bs + n
                self.an_scores[start:stop] = error.reshape(n)
                self.gt_labels[start:stop] = self.gt.reshape(n)
                self.latent_i[start:stop, :] = latent_i.reshape(n, nz)
                self.latent_o[start:stop, :] = latent_o.reshape(n, nz)

                self.times.append(time_o - time_i)

                # Save test images. (These lines were previously nested under
                # ``if not os.path.isdir(dst)``, so images were saved only
                # once and self.times was collapsed to a float mid-loop.)
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)

            # Measure inference time (ms per batch, first 100 batches).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]. Plain min-max scaling divides
            # by zero when every score is identical, and torch.min fails on
            # an empty tensor; handle both explicitly.
            if self.an_scores.numel() > 0:
                score_min = torch.min(self.an_scores)
                score_range = torch.max(self.an_scores) - score_min
                if score_range == 0:
                    self.an_scores = torch.zeros_like(self.an_scores)
                else:
                    self.an_scores = (self.an_scores - score_min) / score_range
            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_valid
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)
            return performance

    ##
    def update_learning_rate(self):
        """ Update learning rate based on the rule provided in options.

        Steps every scheduler in ``self.schedulers`` and prints the learning
        rate of the first param group of ``self.optimizers[0]``. (This method
        was previously nested inside ``test`` after its ``return`` and so was
        never reachable.)
        """
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('   LR = %.7f' % lr)
Ejemplo n.º 5
0
class ocgan(object):
    """GANomaly Class
    """

    @staticmethod
    def name():
        """Return name of the class.
        """
        return 'ocgan'

    def __init__(self, opt, dataloader=None):
        """Build the OCGAN networks, losses, input buffers and optimizers.

        Args:
            opt: Option namespace (reads ``outf``, ``name``, ``isize``,
                ``nc``, ``ndf``, ``ngf``, ``nz``, ``batchsize``, ``resume``,
                ``isTrain``, ``lr`` and ``beta1``).
            dataloader: Stored unchanged as ``self.dataloader``.
        """
        super(ocgan, self).__init__()
        ##
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')

        # -- Discriminator attributes.
        self.out_dv0 = None
        self.out_dv1 = None
        self.out_dv = None
        self.label_dv = None
        self.err_dv_bce = None
        self.out_dl0 = None
        self.out_dl1 = None
        self.out_dl = None
        self.label_dl = None
        self.err_dl_bce = None
        self.err_d = None

        # -- Generator attributes.
        self.out_gv0 = None
        self.out_gv1 = None
        self.out_gv = None
        self.label_gv = None
        self.err_gv_bce = None
        self.out_gl0 = None
        self.out_gl1 = None
        self.out_gl = None
        self.label_gl = None
        self.err_gl_bce = None
        self.err_g_mse = None
        self.err_g = None

        # -- Classfier attribute
        self.out_c0 = None
        self.out_c1 = None
        self.out_c = None
        self.label_c = None
        self.err_c_bce = None

        # -- mine attribute
        self.out_m = None
        self.err_m_bce = None

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks.
        print('Create and initialize networks.')
        self.neten = Encoder(self.opt.isize, self.opt.nc, self.opt.ndf)
        self.netde = Decoder(self.opt.isize, self.opt.nc, self.opt.ngf)
        self.netdl = Dl(self.opt.nz)
        self.netdv = Dv(self.opt)
        self.netc = Classfier(self.opt)
        # NOTE(review): netdl is the only network not passed through
        # weights_init — confirm whether that is intentional.
        self.netdv.apply(weights_init)
        self.netc.apply(weights_init)
        self.neten.apply(weights_init)
        self.netde.apply(weights_init)
        print('end')
        ##
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")

            def _checkpoint(fname):
                # One checkpoint dict from the resume directory.
                return torch.load(os.path.join(self.opt.resume, fname))

            # neten.pth supplies both the epoch counter and the encoder
            # weights; read it from disk once instead of twice.
            neten_ckpt = _checkpoint('neten.pth')
            self.opt.iter = neten_ckpt['epoch']
            self.neten.load_state_dict(neten_ckpt['state_dict'])
            self.netde.load_state_dict(_checkpoint('netde.pth')['state_dict'])
            self.netdl.load_state_dict(_checkpoint('netdl.pth')['state_dict'])
            self.netdv.load_state_dict(_checkpoint('netdv.pth')['state_dict'])
            self.netc.load_state_dict(_checkpoint('netC.pth')['state_dict'])
            print("\tDone.\n")

        print(self.neten)
        print(self.netde)
        print(self.netdl)
        print(self.netdv)
        print(self.netc)
        ##
        # Loss Functions
        self.bce_criterion = nn.BCELoss()
        self.l1l_criterion = nn.L1Loss()
        self.l2l_criterion = nn.MSELoss()

        ##
        # Initialize input tensors.
        # NOTE(review): the image buffers hard-code 3 channels while the
        # networks are built with opt.nc; set_input resizes them per batch.
        print('Initialize input tensors.')
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32)
        self.label = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32)
        self.labelf = torch.empty(size=(self.opt.batchsize,), dtype=torch.float32)
        self.gt = torch.empty(size=(opt.batchsize,), dtype=torch.long)
        self.fixed_input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize, self.opt.isize), dtype=torch.float32)
        self.real_label = 1
        self.fake_label = 0
        self.u = None
        self.n = None
        self.l1 = None
        # NOTE(review): l2 is uninitialized memory with requires_grad=False,
        # yet it is handed to optimizer_l2 below; confirm it is refilled /
        # made trainable elsewhere.
        self.l2 = torch.empty(size=(self.opt.batchsize, self.opt.nz, 1, 1), dtype=torch.float32)
        self.del1 = None
        self.del2 = None
        print('end')

        ##
        # Setup optimizer
        print('Setup optimizer')
        if self.opt.isTrain:
            self.neten.train()
            self.netde.train()
            self.netdv.train()
            self.netdl.train()
            self.netc.train()
            self.optimizer_en = optim.Adam(self.neten.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optimizer_de = optim.Adam(self.netde.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optimizer_dl = optim.Adam(self.netdl.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optimizer_dv = optim.Adam(self.netdv.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optimizer_c = optim.Adam(self.netc.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optimizer_l2 = optim.Adam([{'params':self.l2}], lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        print('end')
    ##
    def set_input(self, input):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        self.input.data.resize_(input[0].size()).copy_(input[0])
        self.gt.data.resize_(input[1].size()).copy_(input[1])

        # Copy the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            self.fixed_input.data.resize_(input[0].size()).copy_(input[0])
    ##
    def update_netd(self):
        """
        Update D network: Ladv = |f(real) - f(fake)|_2
        """
        ##
        # Feature Matching.
        self.netdv.zero_grad()
        self.netdl.zero_grad()
        # --
        self.out_dv0 = self.netdv(self.del2.detach())
        self.out_dv1 = self.netdv(self.input)
        self.out_dv = torch.cat([self.out_dv0, self.out_dv1], 0)

        self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.label.data.resize_(self.opt.batchsize).fill_(self.real_label)
        self.label_dv = torch.cat([self.labelf, self.label], 0)

        self.err_dv_bce = self.bce_criterion(self.out_dv, self.label_dv)

        self.out_dl0 = self.netdl(self.l1.detach())
        self.out_dl1 = self.netdl(self.l2)
        self.out_dl = torch.cat([self.out_dl0, self.out_dl1], 0)

        self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label)
        self.label.data.resize_(self.opt.batchsize).fill_(self.real_label)
        self.label_dl = torch.cat([self.labelf, self.label], 0)

        self.err_dl_bce = self.bce_criterion(self.out_dl, self.label_dl)

        self.err_d=self.err_dl_bce+self.err_dv_bce
        self.err_d.backward(retain_graph=True)
        self.optimizer_dv.step()
        self.optimizer_dl.step()
    ##
    def update_netg(self):
        """Update the encoder ``neten`` and decoder ``netde``.

        The generator loss ``err_g`` is the sum of

        * ``err_gl_bce``: ``bce_criterion(netdl(l1), real_label)`` -- encoder
          codes should look like prior samples to the latent discriminator;
        * ``err_gv_bce``: ``bce_criterion(netdv(del2), real_label)`` -- decoded
          prior samples should look real to the image discriminator;
        * ``err_g_mse``: ``l2l_criterion(input, del1)`` -- reconstruction term.

        Only ``optimizer_en`` / ``optimizer_de`` are stepped. The gradients
        this backward pass leaves on ``netdv``/``netdl`` are not used here;
        ``update_netd`` zeroes them before its own step.
        """
        self.neten.zero_grad()
        self.netde.zero_grad()

        # Both adversarial terms target real_label ("fool the discriminator").
        self.label.data.resize_(self.opt.batchsize).fill_(self.real_label)

        # -- Image-space adversarial term.
        self.out_gv1 = self.netdv(self.del2)
        self.err_gv_bce = self.bce_criterion(self.out_gv1, self.label)

        # -- Latent-space adversarial term.
        self.out_gl1 = self.netdl(self.l1)
        self.err_gl_bce = self.bce_criterion(self.out_gl1, self.label)

        # -- Reconstruction term.
        self.err_g_mse = self.l2l_criterion(self.input, self.del1)

        self.err_g = self.err_gl_bce + self.err_gv_bce + self.err_g_mse
        self.err_g.backward(retain_graph=True)
        self.optimizer_en.step()
        self.optimizer_de.step()
    ##
    def update_netc(self):
        """Train the classifier ``netc`` to tell decoded prior samples from real images.

        ``netc(del2)`` is labelled ``fake_label`` and ``netc(input)`` is
        labelled ``real_label``; the BCE loss ``err_c_bce`` is
        back-propagated and ``optimizer_c`` is stepped. ``del2`` is detached,
        so the decoder receives no gradient from this update.
        """
        self.netc.zero_grad()

        fake_scores = self.netc(self.del2.detach())
        real_scores = self.netc(self.input)
        self.out_c0 = fake_scores
        self.out_c1 = real_scores
        self.out_c = torch.cat([fake_scores, real_scores], 0)

        batch = self.opt.batchsize
        self.labelf.data.resize_(batch).fill_(self.fake_label)
        self.label.data.resize_(batch).fill_(self.real_label)
        self.label_c = torch.cat([self.labelf, self.label], 0)

        loss = self.bce_criterion(self.out_c, self.label_c)
        self.err_c_bce = loss
        loss.backward()
        self.optimizer_c.step()

    def update_l2(self, n_steps=5):
        """Refine the prior codes against the classifier ``netc``.

        Each of ``n_steps`` iterations computes ``bce_criterion(netc(del2),
        fake_label)``, back-propagates it, steps ``optimizer_l2`` and then
        re-decodes ``del2 = netde(l2)``.

        Args:
            n_steps (int): number of refinement iterations; the default of 5
                is the value that used to be hard-coded.

        NOTE(review): ``optimize`` rebinds ``self.l2`` to a brand-new tensor
        every batch, while ``optimizer_l2`` was constructed over whatever
        ``self.l2`` referred to at construction time. Unless that is
        reconciled elsewhere, ``optimizer_l2.step()`` does not move the
        current ``self.l2`` -- confirm against the constructor.
        """
        # The target never changes inside the loop, so fill it once.
        self.labelf.data.resize_(self.opt.batchsize).fill_(self.fake_label)
        for _ in range(n_steps):
            self.optimizer_l2.zero_grad()
            self.out_m = self.netc(self.del2)
            self.err_m_bce = self.bce_criterion(self.out_m, self.labelf)
            self.err_m_bce.backward()
            self.optimizer_l2.step()
            self.del2 = self.netde(self.l2)

    def optimize(self):
        """Run one training step on the current batch.

        Samples a prior code ``l2`` uniformly from [-1, 1) with shape
        ``(batchsize, nz, 1, 1)`` and Gaussian input noise ``n``, encodes the
        noisy input into ``l1``, decodes ``del1 = netde(l1)`` and
        ``del2 = netde(l2)``, then updates ``netc``, the discriminators,
        optionally the prior codes (when ``opt.mine`` is True) and finally
        the encoder/decoder.
        """
        device = self.input.device
        self.u = np.random.uniform(-1, 1, (self.opt.batchsize, self.opt.nz, 1, 1))
        # The random tensors used to be created on the CPU unconditionally,
        # which breaks ``self.input + self.n`` / ``netde(self.l2)`` when the
        # model runs on a GPU. Sampling stays on the CPU RNG (identical values)
        # and the result is moved to the input's device.
        # NOTE(review): rebinding self.l2 means optimizer_l2 (built once,
        # elsewhere) does not track this tensor -- see update_l2.
        self.l2 = torch.from_numpy(self.u).float().to(device)
        self.n = torch.randn(self.opt.batchsize, self.opt.nc,
                             self.opt.isize, self.opt.isize).to(device)
        self.l1 = self.neten(self.input + self.n)
        self.del1 = self.netde(self.l1)
        self.del2 = self.netde(self.l2)
        self.update_netc()
        self.update_netd()
        if self.opt.mine == True:  # exact comparison kept deliberately
            self.update_l2()
        self.update_netg()


    ##
    def get_errors(self):
        """Collect the most recent scalar losses for logging.

        Returns:
            OrderedDict: loss name -> Python float, in a fixed order. The key
            ``'err_gl_l1l'`` holds ``err_gl_bce``; the key is left as-is
            because downstream plotting may rely on it.
        """
        named_losses = (
            ('err_d', self.err_d),
            ('err_g', self.err_g),
            ('err_dv_bce', self.err_dv_bce),
            ('err_dl_bce', self.err_dl_bce),
            ('err_gv_bce', self.err_gv_bce),
            ('err_gl_l1l', self.err_gl_bce),
            ('err_g_mse', self.err_g_mse),
            ('err_c_bce', self.err_c_bce),
        )
        return OrderedDict((key, loss.item()) for key, loss in named_losses)

    ##
    def get_current_images(self):
        """Return the current real batch, its reconstruction and the fixed-batch reconstruction.

        Returns:
            [reals, fakes, fixed]: ``input``, ``del1`` and
            ``netde(neten(fixed_input))``, all without autograd history.
        """
        reals = self.input.data
        fakes = self.del1.data
        # netde returns a plain tensor in this model (del1 is used directly as
        # an image and in l2l_criterion), so indexing its output with [0], as
        # one would for a tuple-returning generator, kept only the first image
        # of the fixed batch. This forward pass is for visualisation only, so
        # no autograd graph is built.
        with torch.no_grad():
            fixed = self.netde(self.neten(self.fixed_input)).data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Checkpoint every sub-network for ``epoch``.

        Writes ``neten.pth``, ``netde.pth``, ``netdl.pth``, ``netdv.pth`` and
        ``netc.pth`` (in that order) under ``<outf>/<name>/train/weights``;
        each file stores ``{'epoch': epoch + 1, 'state_dict': ...}``.

        Args:
            epoch (int): Current (0-based) epoch number.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        for tag in ('neten', 'netde', 'netdl', 'netdv', 'netc'):
            net = getattr(self, tag)
            checkpoint = {'epoch': epoch + 1, 'state_dict': net.state_dict()}
            torch.save(checkpoint, '%s/%s.pth' % (weight_dir, tag))

    ##
    def train_epoch(self):
        """Run one pass over ``dataloader['train']``.

        For every batch: advance ``total_steps`` and the in-epoch counter by
        ``opt.batchsize``, load the batch, call ``optimize`` and, at
        ``opt.print_freq`` / ``opt.save_image_freq`` step intervals, plot the
        losses and save (and optionally display) sample images.
        Only ``neten`` and ``netde`` are switched to train mode here.
        """
        self.neten.train()
        self.netde.train()
        train_loader = self.dataloader['train']
        seen = 0
        for batch in tqdm(train_loader, leave=False, total=len(train_loader)):
            self.total_steps += self.opt.batchsize
            seen += self.opt.batchsize

            self.set_input(batch)
            self.optimize()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    progress = float(seen) / len(train_loader.dataset)
                    self.visualizer.plot_current_errors(self.epoch, progress, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        print(">> Training model %s. Epoch %d/%d" % (self.name(), self.epoch + 1, self.opt.niter))
    ##
    def train(self):
        """Train for epochs ``opt.iter`` .. ``opt.niter - 1``.

        After each epoch the model is evaluated with ``test()``, the best AUC
        seen so far is tracked for reporting, and the weights are saved every
        epoch (not only when the AUC improves).
        """
        self.total_steps = 0
        best_auc = 0

        print(">> Training model %s." % self.name())
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_epoch()
            res = self.test()
            best_auc = max(best_auc, res['AUC'])
            self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name())

    ##
    def test(self):
        """ Test GANomaly model.

        Args:
            dataloader ([type]): Dataloader for the test set

        Raises:
            IOError: Model weights not found.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netc.pth".format(self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netc.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netc weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long)
            self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.float32)
            # print("   Testing model %s." % self.name())
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            # print(self.dataloader['test'])
            # print(type(self.dataloader['test']))
            for i, data in enumerate(self.dataloader['test'], 0):
                #
                # print(data)
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.out_c = self.netc(self.input)

                time_o = time.time()
                self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + self.out_c.size(0)] = self.out_c
                self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+self.out_c.size(0)] = self.gt.reshape(self.out_c.size(0))
                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.png' % (dst, i+1), normalize=True,nrow=4)
                    vutils.save_image(fake, '%s/fake_%03d.png' % (dst, i+1), normalize=True,nrow=4)

            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # auc, eer = roc(self.gt_labels, self.an_scores)
            auc = evaluate(self.gt_labels, self.an_scores,metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)
            return performance
    def test1(self):
        """ Test GANomaly model.

        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/neten.pth".format(self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.neten.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netc weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.gt_labels = torch.zeros(size=(len(self.dataloader['test'].dataset),), dtype=torch.long)
            self.an_scores = torch.zeros(size=(len(self.dataloader['test'].dataset),self.opt.nz), dtype=torch.float32)
            # print("   Testing model %s." % self.name())
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            # print(self.dataloader['test'])
            # print(type(self.dataloader['test']))
            for i, data in enumerate(self.dataloader['test'], 0):
                #
                # print(data)
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.l1 = self.neten(self.input)

                time_o = time.time()
                self.an_scores[i * self.opt.batchsize: i * self.opt.batchsize + self.l1.size(0),:] = self.l1.reshape(self.l1.size(0),self.opt.nz)
                self.gt_labels[i*self.opt.batchsize : i*self.opt.batchsize+self.l1.size(0)] = self.gt.reshape(self.l1.size(0))
                self.times.append(time_o - time_i)

            x=self.an_scores.numpy()
            y=self.gt_labels.numpy()
            return x,y
Ejemplo n.º 6
0
class BaseModel():
    """ Base Model for ganomaly

    Shared training / evaluation loop for a GANomaly-style generator
    (``netg``) and discriminator (``netd``), plus an optional second stage
    that trains latent-space classifiers ``netc_i`` / ``netc_o``.

    This class never defines ``name``, ``netg``, ``netd``, ``netc_i``,
    ``netc_o``, ``optimize_params``, ``z_optimize_params`` or the
    input/label tensors that ``set_input`` / ``z_set_input`` write into;
    they are presumably provided by subclasses.
    """
    def __init__(self, opt, dataloader):
        """Store options and dataloaders; derive output dirs and device.

        Args:
            opt: parsed options namespace.
            dataloader (dict): loaders keyed by split (``'train'``, ``'test'``).
        """
        ##
        # Seed first so everything created afterwards is reproducible.
        self.seed(opt.manualseed)

        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.device = torch.device(
            "cuda:0" if self.opt.device != 'cpu' else "cpu")
        self.sqrtnz = int(self.opt.nz**0.5)

    ##
    def set_input(self, input: torch.Tensor):
        """ Set input and ground truth

        Args:
            input: ``(images, labels)`` batch; images are copied into
                ``self.input`` and labels into ``self.gt``. ``self.label`` is
                only resized (its contents are not set here).
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input. total_steps is bumped
            # by batchsize before this call, so equality marks the first batch
            # after a reset.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])
                self.visualizer.save_fixed_real_s(self.fixed_input)

    def z_set_input(self, net, input: torch.Tensor):
        """Copy a latent-classifier batch into the ``i_*`` or ``o_*`` tensors.

        Args:
            net (str): ``'i'`` or ``'o'``, selecting the tensor family
                (``<net>_input``, ``<net>_gt``, ``<net>_label``,
                ``<net>_real_label``).
            input: ``(features, labels)`` batch.

        Raises:
            ValueError: ``net`` is neither ``'i'`` nor ``'o'`` (this used to
                be silently ignored).
        """
        if net not in ('i', 'o'):
            raise ValueError("net must be 'i' or 'o', got %r" % (net,))

        with torch.no_grad():
            getattr(self, net + '_input').resize_(input[0].size()).copy_(input[0])
            getattr(self, net + '_gt').resize_(input[1].size()).copy_(input[1])
            getattr(self, net + '_label').resize_(input[1].size())
            getattr(self, net + '_real_label').resize_(input[1].size()).copy_(input[1])

    ##
    def seed(self, seed_value):
        """ Seed Python, NumPy and torch RNGs and make cuDNN deterministic.

        Arguments:
            seed_value {int} -- seed to use; -1 leaves all RNGs untouched.
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: loss name -> Python float.
        """

        errors = OrderedDict([('err_d', self.err_d.item()),
                              ('err_g', self.err_g.item()),
                              ('err_g_adv', self.err_g_adv.item()),
                              ('err_g_con', self.err_g_con.item()),
                              ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def z_get_errors(self, net):
        """ Get the latest loss of one latent classifier.

        Args:
            net (str): ``'i'`` for ``err_i`` or ``'o'`` for ``err_o``.

        Returns:
            [OrderedDict]: single-entry loss name -> Python float.

        Raises:
            ValueError: ``net`` is neither ``'i'`` nor ``'o'`` (this used to
                surface as an UnboundLocalError).
        """
        if net == 'i':
            return OrderedDict([('err_i', self.err_i.item())])
        if net == 'o':
            return OrderedDict([('err_o', self.err_o.item())])
        raise ValueError("net must be 'i' or 'o', got %r" % (net,))

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed, fixed_reals]: current input, its
            reconstruction ``fake``, ``netg``'s first output for the fixed
            batch, and the fixed batch itself.
        """

        reals = self.input.data
        fakes = self.fake.data
        # Visualisation only: no autograd graph is needed.
        with torch.no_grad():
            fixed = self.netg(self.fixed_input)[0].data
        fixed_reals = self.fixed_input.data
        return reals, fakes, fixed, fixed_reals

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current (0-based) epoch number; ``epoch + 1`` is stored.
        """

        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netg.state_dict()
        }, '%s/netG.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netd.state_dict()
        }, '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.

        For each batch: advance the step counters, load the batch, call
        ``optimize_params`` and, at the configured step frequencies, plot
        losses and save/display sample images.
        """

        self.netg.train()
        if self.opt.strengthen:
            self.netd.train()
        epoch_iter = 0
        for data in tqdm(self.dataloader['train'],
                         leave=False,
                         total=len(self.dataloader['train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.dataloader['train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed, fixed_reals = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(
                        reals, fakes, fixed, fixed_reals)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train for epochs ``opt.iter`` .. ``opt.niter - 1``.

        After each epoch the model is evaluated with ``test()``; weights are
        saved only when the AUC improves on the best so far.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.train_one_epoch()
            res = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                self.save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    def test(self):
        """ Test GANomaly model.

        Each test image is scored by the mean squared difference between the
        generator's two latent codes; scores are min-max scaled to [0, 1] and
        evaluated with ``opt.metric``. Per-sample latent codes, discriminator
        outputs and features are kept on ``self`` for visualisation and for
        building the latent-classifier dataset.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` and ``'AUC'``.

        Raises:
            IOError: ``opt.load_weights`` is set but netG weights cannot be read.
        """

        if self.opt.strengthen:
            self.netg.eval()
        with torch.no_grad():

            # Optionally restore netg. Only the file read is guarded; a
            # load_state_dict mismatch is a different error and propagates.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name.lower(), self.opt.dataset)
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'

            n_test = len(self.dataloader['test'].dataset)
            # Shape of the saved netd feature map, taken from the module three
            # from the end of netd's first child (presumably a conv layer --
            # it must expose out_channels/kernel_size).
            feat_layer = list(self.netd.children())[0][-3]
            feat_shape = (feat_layer.out_channels,
                          feat_layer.kernel_size[0],
                          feat_layer.kernel_size[1])

            # Create big error tensors for the test set.
            self.an_scores = torch.zeros(size=(n_test, ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(n_test, ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.d_pred = torch.zeros(size=(n_test, ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.last_feature = torch.zeros(size=(n_test, ) + feat_shape,
                                            dtype=torch.float32,
                                            device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)
                d_pred, features = self.netd(self.input)

                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Batch i occupies [i * batchsize, i * batchsize + n).
                n = error.size(0)
                start = i * self.opt.batchsize
                end = start + n
                self.an_scores[start:end] = error.reshape(n)
                self.gt_labels[start:end] = self.gt.reshape(n)
                self.latent_i[start:end, :] = latent_i.reshape(n, self.opt.nz)
                self.latent_o[start:end, :] = latent_o.reshape(n, self.opt.nz)
                self.d_pred[start:start + d_pred.size(0)] = d_pred.reshape(
                    d_pred.size(0))
                self.last_feature[start:end, :] = features.reshape(
                    (n, ) + feat_shape)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    os.makedirs(dst, exist_ok=True)
                    real, fake, _, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time (ms per batch over the first 100 batches).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]. A constant score vector used
            # to yield 0/0 = NaN everywhere; map it to all zeros instead.
            score_min = torch.min(self.an_scores)
            score_range = torch.max(self.an_scores) - score_min
            if score_range > 0:
                self.an_scores = (self.an_scores - score_min) / score_range
            else:
                self.an_scores = torch.zeros_like(self.an_scores)

            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.strengthen and self.opt.phase == 'test':
                t0 = threading.Thread(
                    target=self.visualizer.display_scores_histo,
                    name='histogram ',
                    args=(self.epoch, self.an_scores, self.gt_labels))
                t0.start()
                if self.opt.strengthen > 1:
                    t1 = threading.Thread(
                        target=self.visualizer.display_feature,
                        name='t-SNE visualizer',
                        args=(self.last_feature, self.gt_labels))
                    t2 = threading.Thread(
                        target=self.visualizer.display_latent,
                        name='latent LDA visualizer',
                        args=(self.latent_i, self.latent_o, self.gt_labels, 9,
                              1000, True))
                    t1.start()
                    t2.start()

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_test
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            if self.opt.classifier:
                self.z_dataloader = set_dataset(self.opt, self.latent_i,
                                                self.latent_o, self.gt_labels)
            return performance

    ##
    def z_save_weights(self, epoch):
        """Save both latent-classifier weights (``netC_i.pth``, ``netC_o.pth``).

        Args:
            epoch (int): Current (0-based) epoch number; ``epoch + 1`` is stored.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train',
                                  'weights')
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_i.state_dict()
        }, '%s/netC_i.pth' % (weight_dir))
        torch.save({
            'epoch': epoch + 1,
            'state_dict': self.netc_o.state_dict()
        }, '%s/netC_o.pth' % (weight_dir))

    ##
    def z_train(self):
        """ Train the latent classifiers for epochs ``opt.iter`` .. ``opt.niter - 1``.

        Weights are saved whenever ``opt.z_metric`` improves; the best value
        is recorded via the visualizer at the end.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best = 0

        # Train for niter epochs.
        print(">> Training model classifier")
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.z_train_one_epoch()
            res = self.z_test()
            if res[self.opt.z_metric] > best:
                best = res[self.opt.z_metric]
                self.z_save_weights(self.epoch)
            self.visualizer.print_current_performance(res, best,
                                                      self.opt.z_metric)
        self.visualizer.record_best(best, self.opt.z_metric,
                                    self.opt.abnormal_class,
                                    self.opt.manualseed, 'ganomaly_s')
        print(">> Training model %s.[Done]" % self.name)

    ##
    def z_train_one_epoch(self):
        """ Train both latent classifiers for one epoch.

        Runs a full pass over ``z_dataloader['i_train']`` for ``netc_i``,
        then over ``z_dataloader['o_train']`` for ``netc_o``.
        """

        self.netc_i.train()
        self.netc_o.train()
        epoch_iter = 0

        for data in tqdm(self.z_dataloader['i_train'],
                         leave=False,
                         total=len(self.z_dataloader['i_train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.z_set_input('i', data)
            self.z_optimize_params('i')

        for data in tqdm(self.z_dataloader['o_train'],
                         leave=False,
                         total=len(self.z_dataloader['o_train'])):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.z_set_input('o', data)
            self.z_optimize_params('o')

            # NOTE(review): this reports only err_i, from inside the 'o'
            # loop, with progress measured against the i_train dataset;
            # err_o is never plotted. Confirm this is intended.
            if self.total_steps % self.opt.print_freq == 0:
                errors = self.z_get_errors('i')
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(
                        self.z_dataloader['i_train'].dataset)
                    self.visualizer.plot_current_errors(
                        self.epoch, counter_ratio, errors)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def z_test(self):
        """ Evaluate the latent classifiers.

        The combined prediction is ``w_i * netc_i(i_test) + w_o *
        netc_o(o_test)``, scored against the ``o_test`` labels with
        ``opt.z_metric``. This assumes both test loaders yield samples in
        the same order and count.

        Returns:
            OrderedDict: ``'Avg Run Time (ms/batch)'`` and ``opt.z_metric``.

        Raises:
            IOError: ``opt.z_load_weights`` is set but a checkpoint cannot be read.
        """

        self.netd.eval()
        self.netc_i.eval()
        self.netc_o.eval()
        with torch.no_grad():

            # Optionally restore netD and both classifiers. Only the file
            # reads are guarded; load_state_dict errors propagate unchanged.
            if self.opt.z_load_weights:
                d_path = "./output/{}/{}/train/weights/netD.pth".format(
                    self.name.lower(), self.opt.dataset)
                i_path = "./output/{}/{}/train/weights/netC_i.pth".format(
                    self.name.lower(), self.opt.dataset)
                o_path = "./output/{}/{}/train/weights/netC_o.pth".format(
                    self.name.lower(), self.opt.dataset)
                try:
                    d_pretrained_dict = torch.load(d_path)['state_dict']
                    i_pretrained_dict = torch.load(i_path)['state_dict']
                    o_pretrained_dict = torch.load(o_path)['state_dict']
                except IOError as err:
                    raise IOError("net weights not found") from err
                self.netd.load_state_dict(d_pretrained_dict)
                self.netc_i.load_state_dict(i_pretrained_dict)
                self.netc_o.load_state_dict(o_pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'

            n_i = len(self.z_dataloader['i_test'].dataset)
            n_o = len(self.z_dataloader['o_test'].dataset)

            # Create big prediction / label tensors for the test sets.
            self.i_pred = torch.zeros(size=(n_i, ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.o_pred = torch.zeros(size=(n_o, ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.i_gt_labels = torch.zeros(size=(n_i, ),
                                           dtype=torch.long,
                                           device=self.device)
            self.o_gt_labels = torch.zeros(size=(n_o, ),
                                           dtype=torch.long,
                                           device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0

            for i, data in enumerate(self.z_dataloader['i_test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.z_set_input('i', data)
                i_pred = self.netc_i(self.i_input)
                time_o = time.time()

                start = i * self.opt.batchsize
                end = start + i_pred.size(0)
                self.i_pred[start:end] = i_pred.reshape(i_pred.size(0))
                self.i_gt_labels[start:end] = self.i_gt.reshape(
                    self.i_gt.size(0))

                self.times.append(time_o - time_i)

            for i, data in enumerate(self.z_dataloader['o_test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.z_set_input('o', data)
                o_pred = self.netc_o(self.o_input)
                time_o = time.time()

                start = i * self.opt.batchsize
                end = start + o_pred.size(0)
                self.o_pred[start:end] = o_pred.reshape(o_pred.size(0))
                self.o_gt_labels[start:end] = self.o_gt.reshape(
                    self.o_gt.size(0))

                self.times.append(time_o - time_i)

            # Measure inference time (ms per batch over the first 100 batches).
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Weighted fusion of the two classifiers' outputs.
            self.pred_c = self.i_pred.cpu() * self.opt.w_i + \
                          self.o_pred.cpu() * self.opt.w_o

            scores = evaluate(self.o_gt_labels.cpu(), self.pred_c,
                              self.opt.z_metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       (self.opt.z_metric, scores)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_i
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
Ejemplo n.º 7
0
class Skipganomaly:
    """Skip-GANomaly Class

    Holds a generator (``netg``, built by ``define_G``) and a discriminator
    (``netd``, built by ``define_D``) and provides the training loop
    (``train``/``train_one_epoch``), evaluation (``test``) and ``inference``.
    """
    @property
    def name(self):
        return 'skip-ganomaly'

    def __init__(self, opt, data=None):
        # Initalize variables.
        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.data = data
        self.trn_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        self.tst_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
        self.inf_dir = os.path.join(self.opt.outf, self.opt.name, 'inference')
        self.device = torch.device(
            "cuda:0" if self.opt.device != "cpu" else "cpu")

        # -- Misc attributes
        self.epoch = 0
        self.times = []
        self.total_steps = 0

        ##
        # Create and initialize networks from networks.py.

        self.netg = define_G(self.opt,
                             norm='batch',
                             use_dropout=False,
                             init_type='normal')
        self.netd = define_D(self.opt,
                             norm='batch',
                             use_sigmoid=False,
                             init_type='normal')

        ##
        # Resume training from the checkpoints stored in opt.resume.
        if self.opt.resume != '':
            print("\nLoading pre-trained networks.")
            # Read netG.pth once; it provides both the epoch and the weights.
            checkpoint_g = torch.load(
                os.path.join(self.opt.resume, 'netG.pth'))
            self.opt.iter = checkpoint_g['epoch']
            self.netg.load_state_dict(checkpoint_g['state_dict'])
            self.netd.load_state_dict(
                torch.load(os.path.join(self.opt.resume,
                                        'netD.pth'))['state_dict'])
            print("\tDone.\n")

        print(self.netg)
        print(self.netd)

        ##
        # Loss Functions
        self.l_adv = nn.BCELoss()
        self.l_con = nn.L1Loss()
        self.l_lat = l2_loss

        ##
        # Initialize input tensors.
        self.input = torch.empty(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        # Zero-initialised: forward_g always computes netg(input + noise), so
        # this buffer must be a no-op unless noise is explicitly requested.
        self.noise = torch.zeros(size=(self.opt.batchsize, 3, self.opt.isize,
                                       self.opt.isize),
                                 dtype=torch.float32,
                                 device=self.device)
        self.label = torch.empty(size=(self.opt.batchsize, ),
                                 dtype=torch.float32,
                                 device=self.device)
        self.gt = torch.empty(size=(opt.batchsize, ),
                              dtype=torch.long,
                              device=self.device)
        # Zero-initialised: it is only filled on the first training batch, so
        # test/inference-only runs would otherwise feed uninitialised memory
        # to netg in get_current_images.
        self.fixed_input = torch.zeros(size=(self.opt.batchsize, 3,
                                             self.opt.isize, self.opt.isize),
                                       dtype=torch.float32,
                                       device=self.device)
        self.real_label = torch.ones(size=(self.opt.batchsize, ),
                                     dtype=torch.float32,
                                     device=self.device)
        self.fake_label = torch.zeros(size=(self.opt.batchsize, ),
                                      dtype=torch.float32,
                                      device=self.device)

        ##
        # Setup optimizer
        if self.opt.phase == "train":
            self.netg.train()
            self.netd.train()
            self.optimizers = []
            self.optimizer_d = optim.Adam(self.netd.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizer_g = optim.Adam(self.netg.parameters(),
                                          lr=self.opt.lr,
                                          betas=(self.opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_d)
            self.optimizers.append(self.optimizer_g)
            # NOTE(review): the schedulers are created but never stepped in
            # this class.
            self.schedulers = [
                get_scheduler(optimizer, opt) for optimizer in self.optimizers
            ]

    ##
    def set_input(self, input: torch.Tensor, noise: bool = False):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i; ``input[0]`` holds the
                images and ``input[1]`` the ground-truth labels.
            noise (bool): If True, fill ``self.noise`` with standard-normal
                noise; otherwise reset it to zeros. (default: False)
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # forward_g adds self.noise to self.input, so keep the shapes in
            # sync (the last batch may be smaller) and zero the buffer when no
            # noise is requested.
            self.noise.resize_(input[0].size())
            if noise:
                self.noise.copy_(torch.randn(self.noise.size()))
            else:
                self.noise.zero_()

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """
        errors = OrderedDict([('err_d', self.err_d), ('err_g', self.err_g),
                              ('err_g_adv', self.err_g_adv),
                              ('err_g_con', self.err_g_con),
                              ('err_g_lat', self.err_g_lat)])

        return errors

    ##
    def reinit_d(self):
        """ Initialize the weights of netD
        """
        self.netd.apply(weights_init)
        print('Reloading d net')

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]: the current input batch, its
            reconstruction, and netg's output (first sample) for the fixed
            input batch.
        """

        reals = self.input.data
        fakes = self.fake.data
        # Display-only forward pass: no autograd graph is needed.
        with torch.no_grad():
            fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch: int, is_best: bool = False):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.
            is_best (bool): Write ``net{G,D}_best.pth`` instead of the
                per-epoch ``net{G,D}_<epoch>.pth`` files.
        """
        name = self.opt.dataset if self.opt.dataset else self.opt.dataroot.split(
            "/")[-1]
        weight_dir = os.path.join(self.opt.outf, name, 'train', 'weights')
        os.makedirs(weight_dir, exist_ok=True)

        if is_best:
            torch.save({
                'epoch': epoch,
                'state_dict': self.netg.state_dict()
            }, f'{weight_dir}/netG_best.pth')
            torch.save({
                'epoch': epoch,
                'state_dict': self.netd.state_dict()
            }, f'{weight_dir}/netD_best.pth')
        else:
            torch.save({
                'epoch': epoch,
                'state_dict': self.netd.state_dict()
            }, f"{weight_dir}/netD_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'state_dict': self.netg.state_dict()
            }, f"{weight_dir}/netG_{epoch}.pth")

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with a leading ``module.`` removed from keys."""
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    def load_weights(self, epoch=None, is_best: bool = False, path=None):
        """ Load pre-trained weights of NetG and NetD

        Keyword Arguments:
            epoch {int}     -- Epoch to be loaded  (default: {None})
            is_best {bool}  -- Load the best epoch (default: {False})
            path {str}      -- Directory holding the weight files; when None,
                               ``<outf>/<dataset>/train/weights`` is used
                               (default: {None})

        Raises:
            Exception -- Neither ``epoch`` nor ``is_best`` was given.
            IOError -- A weight file could not be read.
        """

        if epoch is None and is_best is False:
            raise Exception(
                'Please provide epoch to be loaded or choose the best epoch.')

        if is_best:
            fname_g = "netG_best.pth"
            fname_d = "netD_best.pth"
        else:
            fname_g = f"netG_{epoch}.pth"
            fname_d = f"netD_{epoch}.pth"

        if path is None:
            name = self.opt.dataset if self.opt.dataset else self.opt.dataroot.split(
                "/")[-1]
            path_g = f"{self.opt.outf}/{name}/train/weights/{fname_g}"
            path_d = f"{self.opt.outf}/{name}/train/weights/{fname_d}"

        else:
            path_g = path + "/" + fname_g
            path_d = path + "/" + fname_d

        # Load the weights of netg and netd.
        print('>> Loading weights...')

        on_cpu = len(self.opt.gpu_ids) == 0
        # Without GPUs, keep every loaded tensor on CPU storage.
        map_location = (lambda storage, loc: storage) if on_cpu else None
        try:
            weights_g = torch.load(path_g,
                                   map_location=map_location)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        try:
            weights_d = torch.load(path_d,
                                   map_location=map_location)['state_dict']
        except IOError as err:
            raise IOError("netD weights not found") from err

        if on_cpu:
            # Drop a `module.` prefix (added by DataParallel) only where it is
            # actually present, so checkpoints saved without it still load.
            weights_g = self._strip_module_prefix(weights_g)
            weights_d = self._strip_module_prefix(weights_d)
        self.netg.load_state_dict(weights_g)
        self.netd.load_state_dict(weights_d)
        print('   Done.')

    def forward(self):
        self.forward_g()
        self.forward_d()

    def forward_g(self):
        """ Forward propagate through netG
        """
        #TODO: Check, why noised input is used
        self.fake = self.netg(self.input + self.noise)

    def forward_d(self):
        """ Forward propagate through netD
        """
        self.pred_real, self.feat_real = self.netd(self.input)
        self.pred_fake, self.feat_fake = self.netd(self.fake)

    def backward_g(self):
        """ Backpropagate netg
        """
        self.err_g_adv = self.opt.w_adv * self.l_adv(self.pred_fake,
                                                     self.real_label)
        self.err_g_con = self.opt.w_con * self.l_con(self.fake, self.input)
        self.err_g_lat = self.opt.w_lat * self.l_lat(
            self.feat_fake, self.feat_real)  # should be named discriminator
        if self.opt.verbose:
            print(f'err_g_adv: {str(self.err_g_adv)}')
            print(f'err_g_con: {str(self.err_g_con)}')
            print(f'err_g_lat: {str(self.err_g_lat)}')

        self.err_g = self.err_g_adv + self.err_g_con + self.err_g_lat
        self.err_g.backward(retain_graph=True)

    def backward_d(self):
        """ Backpropagate netd

        The fake batch is re-scored from ``self.fake.detach()``. Both
        optimizers step only after backward_g and backward_d have run (see
        optimize_params), so any gradient this loss sent into netg would be
        applied to the generator as well.
        """
        # Fake: score a detached copy so gradients stop at netd.
        pred_fake, feat_fake = self.netd(self.fake.detach())
        self.err_d_fake = self.l_adv(pred_fake, self.fake_label)

        # Real
        self.err_d_real = self.l_adv(self.pred_real, self.real_label)

        # Feature-matching term, computed like err_g_lat but from the detached
        # fake features so that it only reaches netd.
        err_d_lat = self.opt.w_lat * self.l_lat(feat_fake, self.feat_real)

        # Combine losses.
        # TODO: According to https://github.com/samet-akcay/skip-ganomaly/issues/18#issue-728932038 ... Check if lat loss has to be negative in discriminator backprob
        if self.opt.verbose:
            print(f'err_d_real: {str(self.err_d_real)}')
            print(f'err_d_fake: {str(self.err_d_fake)}')
            print(f'err_g_lat: {str(self.err_g_lat)}')
        self.err_d = self.err_d_real + self.err_d_fake + err_d_lat
        self.err_d.backward(retain_graph=True)

    def update_netg(self):
        """ Update Generator Network.
        """
        self.optimizer_g.zero_grad()
        self.backward_g()

    def update_netd(self):
        """ Update Discriminator Network.
        """
        self.optimizer_d.zero_grad()
        self.backward_d()

    ##
    def optimize_params(self):
        """ Optimize netD and netG  networks.

        Both backward passes run before either optimizer steps, so the graph
        built in forward() is never used after an in-place parameter update.
        """
        self.forward()
        self.update_netg()
        self.update_netd()
        self.optimizer_g.step()
        self.optimizer_d.step()
        if self.err_d < 1e-5: self.reinit_d()

    ##
    def train_one_epoch(self):
        """ Train the model for one epoch.
        """
        self.opt.phase = "train"
        self.netg.train()
        self.netd.train()
        epoch_iter = 0
        for data in tqdm(self.data.train,
                         leave=False,
                         total=len(self.data.train)):
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()
            reals, fakes, fixed = self.get_current_images()

            errors = self.get_errors()

            if self.opt.display:
                self.visualizer.plot_current_errors(self.epoch,
                                                    self.total_steps, errors)
                # Write images to tensorboard
                if self.total_steps % self.opt.save_image_freq == 0:
                    self.visualizer.display_current_images(
                        reals,
                        fakes,
                        fixed,
                        train_or_test="train",
                        global_step=self.total_steps)

            if self.total_steps % self.opt.save_image_freq == 0:
                self.visualizer.save_current_images(self.epoch, reals, fakes,
                                                    fixed)

        print(">> Training model %s. Epoch %d/%d" %
              (self.name, self.epoch + 1, self.opt.niter))

    ##
    def train(self):
        """ Train the model

        Runs epochs ``opt.iter`` .. ``opt.niter - 1``. After each epoch the
        model is evaluated with ``test()``, and the weights are saved as
        "best" whenever the AUC improves.
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0

        # Train for niter epochs.
        print(
            f">> Training {self.name} on {self.opt.dataset} to detect anomalies"
        )
        for self.epoch in range(self.opt.iter, self.opt.niter):
            self.train_one_epoch()
            res = self.test()
            if res['auc'] > best_auc:
                best_auc = res['auc']
                self.save_weights(self.epoch, is_best=True)
            self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)

    ##
    def test(self, plot_hist=True):
        """ Test the model on ``self.data.valid``.

        Args:
            plot_hist (bool): Currently unused; kept for interface
                compatibility.

        Returns:
            The performance dictionary produced by ``get_performance``.

        Raises:
            IOError: Model weights not found.
        """
        self.netg.eval()
        self.netd.eval()
        with torch.no_grad():
            # Load the weights of netg and netd.
            # NOTE(review): this also runs for every evaluation inside train(),
            # replacing the weights being trained whenever path_to_weights is set.
            if self.opt.path_to_weights is not None:
                self.load_weights(path=self.opt.path_to_weights, is_best=True)

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(len(self.data.valid.dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(self.data.valid.dataset), ),
                                         dtype=torch.long,
                                         device=self.device)

            print("   Testing %s" % self.name)
            self.times = []
            total_steps_test = 0
            epoch_iter = 0
            for i, data in enumerate(
                    tqdm(self.data.valid,
                         leave=False,
                         total=len(self.data.valid))):
                total_steps_test += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()

                # Forward - Pass
                self.forward_for_testing(data)

                # Calculate the anomaly score.
                error = self.calculate_an_score()

                time_o = time.time()

                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                if self.opt.verbose:
                    print(f'an_scores: {str(self.an_scores)}')
                self.times.append(time_o - time_i)
                real, fake, fixed = self.get_current_images()

                # Take the modulo of the whole step: `%` binds tighter than
                # `+`, so the unparenthesised test only ever fired in epoch 0.
                global_step = self.epoch * len(
                    self.data.valid) + total_steps_test
                if global_step % self.opt.save_image_freq == 0:
                    self.visualizer.display_current_images(
                        real,
                        fake,
                        fixed,
                        train_or_test="test",
                        global_step=global_step)
                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst): os.makedirs(dst)

                    #iterate over them (real) and write anomaly score and ground truth on filename
                    vutils.save_image(real,
                                      '%s/real_%03d.png' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.png' % (dst, i + 1),
                                      normalize=True)
            # Measure inference time.
            self.times = np.array(self.times)
            self.times = np.mean(self.times * 1000)

            # Scale error vector between [0, 1]
            # self.an_scores = (self.an_scores - torch.min(self.an_scores))/(torch.max(self.an_scores) - torch.min(self.an_scores))
            if self.opt.verbose:
                print(f'scaled an_scores: {str(self.an_scores)}')
            y_trues = self.gt_labels.cpu()
            y_preds = self.an_scores.cpu()
            # Create data frame for scores and labels.
            performance, thresholds, y_preds_man, y_preds_auc = get_performance(
                y_trues=y_trues,
                y_preds=y_preds,
                manual_threshold=self.opt.decision_threshold)
            with open(
                    os.path.join(self.opt.outf,
                                 self.opt.phase + "_results.txt"), "a+") as f:
                f.write(str(performance))
                f.write("\n")
            self.visualizer.plot_histogram(
                y_trues=y_trues,
                y_preds=y_preds,
                threshold=performance["threshold"],
                save_path=os.path.join(
                    self.opt.outf,
                    "histogram_test" + str(self.epoch) + ".png"),
                tag="Histogram_Test",
                global_step=self.epoch)
            self.visualizer.plot_pr_curve(y_trues=y_trues,
                                          y_preds=y_preds,
                                          thresholds=thresholds,
                                          global_step=self.epoch,
                                          tag="PR_Curve_Test")
            self.visualizer.plot_roc_curve(
                y_trues=y_trues,
                y_preds=y_preds,
                global_step=self.epoch,
                tag="ROC_Curve_Test",
                save_path=os.path.join(self.opt.outf,
                                       "roc_test" + str(self.epoch) + ".png"))

            self.visualizer.plot_current_conf_matrix(
                self.epoch,
                performance["conf_matrix"],
                tag="Confusion_Matrix_Test",
                save_path=os.path.join(self.opt.outf,
                                       self.opt.phase + "_conf_matrix.png"))

            self.visualizer.plot_performance(self.epoch,
                                             0,
                                             performance,
                                             tag="Performance_Test")
            return performance

    def forward_for_testing(self, data):
        """Load *data* and compute the reconstruction and netd features."""
        self.set_input(data)
        self.fake = self.netg(self.input)

        # Only the features are needed for the anomaly score.
        _, self.feat_real = self.netd(self.input)
        _, self.feat_fake = self.netd(self.fake)

    def inference(self):
        """Score ``self.data.inference`` with the best saved weights.

        Writes the results file, the plots and the classification JSON files
        under ``opt.outf``, and returns the performance dictionary.
        """
        self.netg.eval()
        self.netd.eval()
        with torch.no_grad():
            self.load_weights(path=self.opt.path_to_weights, is_best=True)

            # Create big error tensor for the test set.
            self.an_scores = torch.zeros(size=(len(
                self.data.inference.dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.data.inference.dataset), ),
                                         dtype=torch.long,
                                         device=self.device)

            print("Starting Inference!")
            inf_times = []

            self.file_names = []
            for i, data in tqdm(enumerate(self.data.inference),
                                leave=False,
                                total=len(self.data.inference)):
                inf_start = time.time()

                # Forward - Pass
                self.forward_for_testing(data)

                # Calculate the anomaly score.
                error = self.calculate_an_score()

                inf_times.append(time.time() - inf_start)

                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                if self.opt.verbose:
                    print(f'an_scores: {str(self.an_scores)}')
                real, fake, fixed = self.get_current_images()

                if i % self.opt.save_image_freq == 0:
                    self.visualizer.display_current_images(
                        real,
                        fake,
                        fixed,
                        train_or_test="test_inference",
                        global_step=i)

                self.file_names.append(data[2])

        # Scale error vector between [0, 1] TODO: does it work without normalizing?
        # self.an_scores = (self.an_scores - torch.min(self.an_scores))/(torch.max(self.an_scores) - torch.min(self.an_scores))
        if self.opt.verbose:
            print(f'scaled an_scores: {str(self.an_scores)}')
        y_trues = self.gt_labels.cpu()
        y_preds = self.an_scores.cpu()
        # Measure inference time.
        inf_time = sum(inf_times)
        print(f'Inference time: {inf_time} secs')
        print(f'Inference time / individual: {inf_time/len(y_trues)} secs')

        # Create data frame for scores and labels.
        performance, thresholds, y_preds_man, y_preds_auc = get_performance(
            y_trues=y_trues,
            y_preds=y_preds,
            manual_threshold=self.opt.decision_threshold)
        with open(os.path.join(self.opt.outf, self.opt.phase + "_results.txt"),
                  "w") as f:
            f.write(str(performance))
        self.visualizer.plot_histogram(y_trues=y_trues,
                                       y_preds=y_preds,
                                       threshold=performance["threshold"],
                                       save_path=os.path.join(
                                           self.opt.outf,
                                           "histogram_inference.png"),
                                       tag="Histogram_Inference")
        self.visualizer.plot_pr_curve(y_trues=y_trues,
                                      y_preds=y_preds,
                                      thresholds=thresholds,
                                      global_step=1,
                                      tag="PR_Curve_Inference")
        self.visualizer.plot_roc_curve(y_trues=y_trues,
                                       y_preds=y_preds,
                                       global_step=1,
                                       tag="ROC_Curve_Inference",
                                       save_path=os.path.join(
                                           self.opt.outf, "roc_inference.png"))

        self.visualizer.plot_current_conf_matrix(
            1,
            performance["conf_matrix"],
            tag="Confusion_Matrix_Inference",
            save_path=os.path.join(self.opt.outf,
                                   self.opt.phase + "_conf_matrix.png"))
        if self.opt.decision_threshold:
            self.visualizer.plot_current_conf_matrix(
                2,
                performance["conf_matrix_man"],
                save_path=os.path.join(self.opt.outf, self.opt.phase +
                                       "_conf_matrix_man.png"))
            self.visualizer.plot_histogram(
                y_trues=y_trues,
                y_preds=y_preds,
                threshold=performance["manual_threshold"],
                global_step=2,
                save_path=os.path.join(self.opt.outf,
                                       "histogram_inference_man.png"),
                tag="Histogram_Inference")
            write_inference_result(file_names=self.file_names,
                                   y_trues=y_trues,
                                   y_preds=y_preds_man,
                                   outf=os.path.join(
                                       self.opt.outf,
                                       "classification_result_man.json"))
        self.visualizer.plot_performance(1,
                                         0,
                                         performance,
                                         tag="Performance_Inference")

        write_inference_result(file_names=self.file_names,
                               y_trues=y_trues,
                               y_preds=y_preds_auc,
                               outf=os.path.join(self.opt.outf,
                                                 "classification_result.json"))
        ##
        # RETURN
        return performance

    def calculate_an_score(self):
        """Per-sample anomaly score: 0.9 * reconstruction MSE + 0.1 * feature MSE.

        Assumes 4-D ``input``/``fake`` and ``feat_real``/``feat_fake`` tensors,
        as required by the ``view`` calls below.
        """
        si = self.input.size()
        sz = self.feat_real.size()
        rec = (self.input - self.fake).view(si[0], si[1] * si[2] * si[3])
        lat = (self.feat_real - self.feat_fake).view(sz[0],
                                                     sz[1] * sz[2] * sz[3])
        rec = torch.mean(torch.pow(rec, 2), dim=1)
        lat = torch.mean(torch.pow(lat, 2), dim=1)
        if self.opt.verbose:
            print(f'rec: {str(rec)}')
            print(f'lat: {str(lat)}')
        error = 0.9 * rec + 0.1 * lat

        return error
Ejemplo n.º 8
0
class BaseModel():
    """ Base Model for ganomaly
    """
    def __init__(self, opt, dataloader):
        """Set up the state shared by the GANomaly-style models.

        Args:
            opt: Option namespace; this constructor reads ``manualseed``,
                ``outf``, ``name`` and ``device`` from it.
            dataloader: Loaders used for training/testing; other methods
                index it with ``'train'`` and ``'test'``.
        """
        # Seed before anything else so every later step starts from the
        # seeded RNG state.
        self.seed(opt.manualseed)

        self.opt = opt
        self.visualizer = Visualizer(opt)
        self.dataloader = dataloader

        run_dir = os.path.join(self.opt.outf, self.opt.name)
        self.trn_dir = os.path.join(run_dir, 'train')
        self.tst_dir = os.path.join(run_dir, 'test')

        on_cpu = self.opt.device == 'cpu'
        self.device = torch.device("cpu" if on_cpu else "cuda:0")

    ##
    def set_input(self, input:torch.Tensor):
        """ Set input and ground truth

        Args:
            input (FloatTensor): Input data for batch i.
        """
        with torch.no_grad():
            self.input.resize_(input[0].size()).copy_(input[0])
            self.gt.resize_(input[1].size()).copy_(input[1])
            self.label.resize_(input[1].size())

            # Copy the first batch as the fixed input.
            if self.total_steps == self.opt.batchsize:
                self.fixed_input.resize_(input[0].size()).copy_(input[0])

    ##
    def seed(self, seed_value):
        """ Seed 
        
        Arguments:
            seed_value {int} -- [description]
        """
        # Check if seed is default value
        if seed_value == -1:
            return

        # Otherwise seed all functionality
        import random
        random.seed(seed_value)
        torch.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        np.random.seed(seed_value)
        torch.backends.cudnn.deterministic = True

    ##
    def get_errors(self):
        """ Get netD and netG errors.

        Returns:
            [OrderedDict]: Dictionary containing errors.
        """

        errors = OrderedDict([
            ('err_d', self.err_d.item()),
            ('err_g', self.err_g.item()),
            ('err_g_adv', self.err_g_adv.item()),
            ('err_g_con', self.err_g_con.item()),
            ('err_g_enc', self.err_g_enc.item())])

        return errors

    ##
    def get_current_images(self):
        """ Returns current images.

        Returns:
            [reals, fakes, fixed]
        """

        reals = self.input.data
        fakes = self.fake.data
        fixed = self.netg(self.fixed_input)[0].data

        return reals, fakes, fixed

    ##
    def save_weights(self, epoch):
        """Save netG and netD weights for the current epoch.

        Args:
            epoch ([int]): Current epoch number.
        """
        weight_dir = os.path.join(self.opt.outf, self.opt.name, 'train', 'weights')
        if not os.path.exists(weight_dir): os.makedirs(weight_dir)
        print(weight_dir)
        torch.save({'epoch': epoch + 1, 'state_dict': self.netg.state_dict()},
                   '%s/netG.pth' % (weight_dir))
        torch.save({'epoch': epoch + 1, 'state_dict': self.netd.state_dict()},
                   '%s/netD.pth' % (weight_dir))

    ##
    def train_one_epoch(self, epochNUM):
        """Train the model for one epoch.

        Args:
            epochNUM (int): 1-based position of this epoch within the current
                ``train()`` call; only used for the progress value sent
                through ``opt.signalInfo``.
        """
        self.netg.train()
        loader = self.dataloader['train']
        num_batches = len(loader)
        epoch_iter = 0
        for batch_idx, data in enumerate(tqdm(loader, leave=False, total=num_batches)):
            # At batch 0 the progress is 10 + 85 * (epochNUM - 1) / niter and
            # it grows through the first 80% of this epoch's 85 / niter share,
            # so the reported progress never falls back between epochs.
            progress = 10 + 85 * ((epochNUM - 1) + 0.8 * batch_idx / num_batches) / self.opt.niter
            self.opt.signalInfo.emit(progress, "")
            self.total_steps += self.opt.batchsize
            epoch_iter += self.opt.batchsize

            self.set_input(data)
            self.optimize_params()

            if self.total_steps % self.opt.print_freq == 0:
                errors = self.get_errors()
                if self.opt.display:
                    counter_ratio = float(epoch_iter) / len(loader.dataset)
                    self.visualizer.plot_current_errors(self.epoch, counter_ratio, errors)

            if self.total_steps % self.opt.save_image_freq == 0:
                reals, fakes, fixed = self.get_current_images()
                self.visualizer.save_current_images(self.epoch, reals, fakes, fixed)
                if self.opt.display:
                    self.visualizer.display_current_images(reals, fakes, fixed)

        message = ">> Training model %s. Epoch %d/%d" % (self.name, self.epoch+1, self.opt.niter)
        print(message)

    ##
    def train(self):
        """ Train the model
        """

        ##
        # TRAIN
        self.total_steps = 0
        best_auc = 0
        best_info = None

        # Train for niter epochs.
        print(">> Training model %s." % self.name)
        self.opt.signalInfo.emit(-1,">> Training model {}.".format(self.name))
        i = 0
        for self.epoch in range(self.opt.iter, self.opt.niter):
            # Train for one epoch
            self.opt.signalInfo.emit(-1,'正在进行第{}个epoch的训练....'.format(i + 1))
            num = self.train_one_epoch(i+1)
            i += 1
            # self.save_weights(self.epoch)
            self.opt.signalInfo.emit(10 + 0.8*85* (i / self.opt.niter),'第{}个epoch的训练完毕!\n正在对该epoch进行测试....'.format(i))
            res,info = self.test()
            if res['AUC'] > best_auc:
                best_auc = res['AUC']
                best_info = info
                self.save_weights(self.epoch)
            infoSTR = ""
            for key,value in info.items():
                infoSTR += str(key)+":"+str(value)+"\n"
            self.opt.signalInfo.emit(10 + 85* (i/ self.opt.niter), '测试完毕!\n第{}个epoch训练结果:\n{}'.format(i, infoSTR))
            # self.visualizer.print_current_performance(res, best_auc)
        print(">> Training model %s.[Done]" % self.name)
        self.opt.signalInfo.emit(-1,">> Training model {}.[Done]".format(self.name))
        # dict_info = {}
        # dict_info['minVal'] = 0.1
        # dict_info['maxVal'] = 0.8
        # dict_info['proline'] = 0.40
        # dict_info['auc'] = 0.93
        # dict_info['Avg Run Time (ms/batch)'] = 9
        # best_info = dict_info
        return best_info

    ##
    def test(self):
        """Evaluate the generator on the test split.

        Each test sample is scored by the mean squared distance between the
        generator's two latent outputs. The scores are min-max scaled to
        [0, 1], a manual decision threshold is derived from score histograms
        and the AUC is computed with ``evaluate``.

        Returns:
            tuple: ``(performance, dict_info)``. ``performance`` is an
            ``OrderedDict`` with ``'Avg Run Time (ms/batch)'`` and ``'AUC'``.
            ``dict_info`` holds ``minVal``/``maxVal`` (raw score range before
            scaling), ``proline`` (the threshold), ``auc`` and
            ``'Avg Run Time (ms/batch)'`` as floats.

        Raises:
            IOError: ``opt.load_weights`` is set and the netG checkpoint
                cannot be read.
        """
        with torch.no_grad():
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                # torch.load is the call that fails for a missing/unreadable
                # file, so it must sit inside the try for the re-raise to apply.
                try:
                    pretrained_dict = torch.load(path)['state_dict']
                except IOError as err:
                    raise IOError("netG weights not found") from err
                self.netg.load_state_dict(pretrained_dict)
                print('   Loaded weights.')

            self.opt.phase = 'test'
            # Preallocate result buffers for the whole test set.
            n_samples = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_samples,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_samples,), dtype=torch.long, device=self.device)
            self.latent_i = torch.zeros(size=(n_samples, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o = torch.zeros(size=(n_samples, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0

            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
                time_o = time.time()

                batch = error.size(0)
                start = i * self.opt.batchsize
                stop = start + batch
                self.an_scores[start:stop] = error.reshape(batch)
                self.gt_labels[start:stop] = self.gt.reshape(batch)
                self.latent_i[start:stop, :] = latent_i.reshape(batch, self.opt.nz)
                self.latent_o[start:stop, :] = latent_o.reshape(batch, self.opt.nz)

                self.times.append(time_o - time_i)

                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    os.makedirs(dst, exist_ok=True)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i+1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i+1), normalize=True)

            # Mean inference time over (at most) the first 100 batches, in ms.
            self.times = np.mean(np.array(self.times)[:100] * 1000)

            # Min-max scale the scores to [0, 1]; the raw range is returned.
            minNUM = torch.min(self.an_scores)
            maxNUM = torch.max(self.an_scores)
            print(minNUM)
            print(maxNUM)
            score_range = maxNUM - minNUM
            if score_range > 0:
                self.an_scores = (self.an_scores - minNUM) / score_range
            else:
                # All scores identical: avoid 0/0 NaNs that would break evaluate().
                self.an_scores = torch.zeros_like(self.an_scores)

            proline = self._manual_threshold(self.an_scores, self.gt_labels)
            print(proline)

            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / n_samples
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)

            dict_info = {
                'minVal': float(minNUM.item()),
                'maxVal': float(maxNUM.item()),
                'proline': float(proline),
                'auc': float(auc),
                'Avg Run Time (ms/batch)': float(self.times),
            }
            return performance, dict_info

    @staticmethod
    def _manual_threshold(scores, labels):
        """Derive a manual decision threshold from scaled anomaly scores.

        ``scores`` (expected in [0, 1]) are split by ``labels`` (0 = normal,
        1 = abnormal) and counted into 50 half-open bins ``[k, k + 0.02)``.
        For each cut index ``k`` the number of normal samples in bins below
        ``k`` plus abnormal samples in bins ``k`` and above is computed; the
        last cut whose count reaches the running best (initially 25.0) sets
        the threshold to ``k / 50 + 0.05``.

        Args:
            scores (torch.Tensor): 1-D scaled scores (any device).
            labels (torch.Tensor): 1-D labels aligned with ``scores``.

        Returns:
            float: Threshold rounded to 3 decimals, or 0.4 if no cut qualifies.
        """
        # Move to host memory first: numpy cannot read CUDA tensors directly.
        scores = scores.detach().cpu().numpy().astype(np.float64)
        labels = labels.detach().cpu().numpy()
        nrm_scr = scores[labels == 0]
        abn_scr = scores[labels == 1]

        bin_starts = np.arange(0, 1, 0.02)
        nrm = np.zeros(len(bin_starts), dtype=np.int64)
        abn = np.zeros(len(bin_starts), dtype=np.int64)
        # Index bins by position rather than int(k * 50): float rounding makes
        # that land in the previous bin for some k (int(0.58 * 50) == 28).
        for idx, lo in enumerate(bin_starts):
            hi = lo + 0.02
            nrm[idx] = np.count_nonzero((nrm_scr >= lo) & (nrm_scr < hi))
            abn[idx] = np.count_nonzero((abn_scr >= lo) & (abn_scr < hi))

        minfix = 0.4
        # NOTE(review): the starting bar is (number of bins) * 2 * 0.25 = 25,
        # not a fraction of the sample count -- confirm which was intended.
        max_dist = (len(nrm) + len(abn)) * 0.25
        for k in range(len(bin_starts)):
            separated = np.sum(nrm[:k]) + np.sum(abn[k:])
            if separated >= max_dist:
                # NOTE(review): the +0.05 offset exceeds one bin width (0.02).
                minfix = round((k / 50) + 0.05, 3)
                max_dist = separated
        return minfix



    def FinalTest(self, minVal, maxVal,threshold=0.2):
        """Score every image under ``<opt.dataroot>/final_test/`` with netG.

        Loads ``./output/ganomaly/<opt.dataset>/train/weights/netG.pth``.
        Each image is scored by the mean squared distance between the
        generator's latent outputs, min-max scaled with ``minVal``/``maxVal``
        and classified with ``threshold``. Progress is reported through
        ``opt.signalInfo``.

        Args:
            minVal (float): Raw score mapped to 0.
            maxVal (float): Raw score mapped to 1.
            threshold (float): Scaled scores ``>= threshold`` count as abnormal.

        Returns:
            tuple: ``(normal, abnormal)`` dicts mapping image file name to its
            scaled score.

        Raises:
            IOError: The netG checkpoint cannot be read.
        """
        path = "./output/{}/{}/train/weights/netG.pth".format('ganomaly', self.opt.dataset)
        print('***'*10)
        print(path)
        print('***' * 10)
        # torch.load is the call that fails for a missing file.
        try:
            pretrained_dict = torch.load(path)['state_dict']
        except IOError as err:
            raise IOError("netG weights not found") from err
        self.opt.signalInfo.emit(-1, '加载权重...')
        self.netg.load_state_dict(pretrained_dict)
        print('   Loaded weights.')
        self.opt.signalInfo.emit(5,"权重加载完毕!\n正在加载图片....")

        path2 = self.opt.dataroot + '/final_test/'
        transform = transforms.Compose([transforms.Resize(self.opt.isize),
                                        transforms.CenterCrop(self.opt.isize),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
        testdata = ImageFolder(path2, transform)
        testdata2 = torch.utils.data.DataLoader(dataset=testdata,
                                                batch_size=1,
                                                shuffle=False,
                                                num_workers=int(self.opt.workers),
                                                drop_last=True)

        self.opt.signalInfo.emit(10, "图片加载完毕!\n正在进行检测....")
        # Take names in the exact order the dataset yields images; an
        # os.listdir() listing has arbitrary order and can mislabel results.
        testFilesName = [os.path.basename(sample_path) for sample_path, _ in testdata.samples]
        n_files = len(testFilesName)
        score_span = maxVal - minVal
        testFilesResNor = {}
        testFilesResAbn = {}

        # NOTE(review): netg's train/eval mode is not changed here (test()
        # does not change it either, so minVal/maxVal come from the same
        # mode) -- confirm before adding eval().
        with torch.no_grad():
            for i, data2 in enumerate(testdata2, 0):
                self.set_input(data2)
                self.fake, latent_i, latent_o = self.netg(self.input)
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                testscore = ((error - minVal) / score_span).item()
                file_name = testFilesName[i]
                if testscore >= threshold:
                    testFilesResAbn[file_name] = testscore
                else:
                    testFilesResNor[file_name] = testscore
                self.opt.signalInfo.emit(10+(i+1)/n_files*88,"")

        self.opt.signalInfo.emit(100, "图片检测完毕!")
        torch.cuda.empty_cache()
        return testFilesResNor,testFilesResAbn