Example #1
0
def main(folderb, folderw, black, white):
    """Load two saved players (black and white) and run an evaluation match.

    Args:
        folderb: Folder holding the black player's checkpoint.
        folderw: Folder holding the white player's checkpoint.
        black: Identifier of the black checkpoint to load.
        white: Identifier of the white checkpoint to load.
    """
    # Restore each colour's agent from its checkpoint folder.
    black_agent, black_ckpt = load_player(str(folderb), black)
    white_agent, white_ckpt = load_player(str(folderw), white)

    print("Loaded Black player with iteration " + str(black_ckpt['total_ite']))
    print("Loaded White player with iteration " + str(white_ckpt['total_ite']))

    # PyTorch multiprocessing needs the 'spawn' start method.
    multiprocessing.set_start_method('spawn')

    evaluate(black_agent, white_agent)
Example #2
0
    def test(self):
        """ Test GANomaly model.

        Runs one pass over the test dataloader, collecting a per-sample
        anomaly score (latent reconstruction error), the ground-truth
        labels and both latent codes, then evaluates the scores.

        Args:
            dataloader ([type]): Dataloader for the test set

        Raises:
            IOError: Model weights not found.

        Returns:
            OrderedDict: 'Avg Run Time (ms/batch)' and 'AUC' of the
            normalized anomaly scores against the ground-truth labels.
        """
        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    # NOTE(review): load_state_dict raises RuntimeError (not
                    # IOError) on a mismatch, so this handler likely never
                    # fires — confirm the intended failure mode.
                    raise IOError("netG weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            # Preallocate per-sample buffers sized to the full test set:
            # anomaly scores, labels, and both latent codes (nz-dim each).
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)

            # print("   Testing model %s." % self.name())
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                # Anomaly score: mean squared distance between encoder and
                # re-encoder latent vectors — one scalar per sample.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Scatter this batch's results into the preallocated buffers;
                # error.size(0) copes with a short final batch.
                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                self.latent_i[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_i.reshape(
                                  error.size(0), self.opt.nz)
                self.latent_o[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_o.reshape(
                                  error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)

            # Measure inference time.
            # Collapses self.times to the mean of the first 100 batch
            # timings, converted to milliseconds.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            # NOTE(review): degenerates to NaN if all scores are equal
            # (max == min) — confirm that cannot happen in practice.
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))
            # auc, eer = roc(self.gt_labels, self.an_scores)
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            # Optionally plot performance against the fraction of the test
            # set processed this epoch.
            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
Example #3
0
    def test_1(self):
        """Evaluate on the test loader with rotation-prediction scoring.

        The loader yields batches whose images come in groups of 4
        (presumably 4 rotations of one sample — TODO confirm).  For each
        group this computes classifier-based anomaly scores on real and
        generated images, plus classifier- and image-space reconstruction
        distances, then prints ROC-AUC for several score combinations.

        Side effects: writes raw logits to ./output/log.txt and labels to
        ./output/label.txt.
        """

        with torch.no_grad():

            self.total_steps_test = 0
            epoch_iter = 0
            print('test')
            # Preallocated per-group buffers.
            # NOTE(review): sizes are hard-coded to 10000 groups — assumes
            # a fixed-size test set; confirm against the loader.
            label = torch.zeros(size=(10000, ),
                                dtype=torch.long,
                                device=self.device)
            pre = torch.zeros(size=(10000, ),
                              dtype=torch.float32,
                              device=self.device)
            pre_real = torch.zeros(size=(10000, ),
                                   dtype=torch.float32,
                                   device=self.device)

            self.relation = torch.zeros(size=(10000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.relation_img = torch.zeros(size=(10000, ),
                                            dtype=torch.float32,
                                            device=self.device)

            # Raw logits (4 classes) per group, dumped to disk at the end.
            self.classifiear = torch.zeros(size=(10000, 4),
                                           dtype=torch.float32,
                                           device=self.device)
            self.opt.phase = 'test'
            for i, (x, y, z) in enumerate(self.test_loader):
                self.input = Variable(x)
                self.label_rrr = Variable(z)

                self.input = self.input.to(self.device)
                self.label_rrr = self.label_rrr.to(self.device)

                # Split the batch into 4 interleaved sub-batches: element
                # j*4+r goes to input_(r+1).
                size = int(self.input.size(0) / 4)
                input_1 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_2 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_3 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)
                input_4 = torch.empty(size=(size, 3, self.opt.isize,
                                            self.opt.isize),
                                      dtype=torch.float32,
                                      device=self.device)

                # Classifier scores on the real (untransformed) batch.
                classfiear_real_1 = self.netc(self.input)
                classfiear_real = F.softmax(classfiear_real_1, dim=1)

                # Negative log-probabilities (elementwise, all classes).
                prediction_real = -(torch.log(classfiear_real))
                for j in range(size):
                    input_1[j] = self.input[j * 4]
                    input_2[j] = self.input[j * 4 + 1]
                    input_3[j] = self.input[j * 4 + 2]
                    input_4[j] = self.input[j * 4 + 3]
                # Generator reconstructions for each sub-batch.
                output_1 = self.netg(input_1)
                output_2 = self.netg(input_2)
                output_3 = self.netg(input_3)
                output_4 = self.netg(input_4)
                classifiear_real = self.netc(input_1)

                # Classifier scores on the reconstructions.
                classfiear_11 = self.netc(output_1)
                classfiear_21 = self.netc(output_2)
                classfiear_31 = self.netc(output_3)
                classfiear_41 = self.netc(output_4)

                classfiear_1 = F.softmax(classfiear_11, dim=1)
                classfiear_2 = F.softmax(classfiear_21, dim=1)
                classfiear_3 = F.softmax(classfiear_31, dim=1)
                classfiear_4 = F.softmax(classfiear_41, dim=1)

                prediction_1 = -(torch.log(classfiear_1))
                prediction_2 = -(torch.log(classfiear_2))
                prediction_3 = -(torch.log(classfiear_3))
                prediction_4 = -(torch.log(classfiear_4))

                aaaa = prediction_1.size(0)
                # NOTE(review): stride 16 here implies a batch of 64 images
                # (16 groups of 4) — confirm it matches opt.batchsize.
                self.classifiear[i * 16:i * 16 + aaaa] = classfiear_11

                # prediction = prediction * (-1/4)

                label_z = torch.zeros(size=(aaaa, ),
                                      dtype=torch.long,
                                      device=self.device)
                pre_score = torch.zeros(size=(aaaa, ),
                                        dtype=prediction_1.dtype,
                                        device=self.device)
                pre_score_real = torch.zeros(size=(aaaa, ),
                                             dtype=prediction_1.dtype,
                                             device=self.device)

                # Image-space reconstruction distance for the first rotation
                # only (mean squared error per image).
                distance_img = torch.mean(torch.pow((output_1 - input_1), 2),
                                          -1)
                distance_img = torch.mean(torch.mean(distance_img, -1), -1)

                # Classifier-space distance between logits of real and
                # reconstructed first-rotation images.
                distance = torch.mean(
                    torch.pow((classifiear_real - classfiear_11), 2), -1)

                self.relation[i * 16:i * 16 +
                              distance.size(0)] = distance.reshape(
                                  distance.size(0))
                self.relation_img[i * 16:i * 16 +
                                  distance.size(0)] = distance_img.reshape(
                                      distance.size(0))

                # Per-group score: average the negative log-probability of
                # rotation r under class r (diagonal of the 4x4 pattern).
                for k in range(aaaa):
                    label_z[k] = self.label_rrr[k * 4]
                    pre_score[k] = (prediction_1[k, 0] + prediction_2[k, 1] +
                                    prediction_3[k, 2] +
                                    prediction_4[k, 3]) / 4
                    pre_score_real[k] = (prediction_real[k * 4, 0] +
                                         prediction_real[k * 4 + 1, 1] +
                                         prediction_real[k * 4 + 2, 2] +
                                         prediction_real[k * 4 + 3, 3]) / 4

                label[i * 16:i * 16 + aaaa] = label_z
                pre[i * 16:i * 16 + aaaa] = pre_score
                pre_real[i * 16:i * 16 + aaaa] = pre_score_real

            # Combined scores: additive (D) and multiplicative (mu) fusion
            # of classifier score and classifier-space distance.
            D = pre + self.relation * 0.2
            D_real = pre_real + self.relation * 0.2

            aaaa = self.classifiear.cpu().numpy()
            np.savetxt('./output/log.txt', aaaa)
            bbbb = label.cpu().numpy()
            np.savetxt('./output/label.txt', bbbb)

            mu = torch.mul(pre, self.relation)
            mu_real = torch.mul(pre_real, self.relation)

            # ROC-AUC for each scoring variant.
            auc_mu_fake = evaluate(label, mu, metric=self.opt.metric)
            auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric)
            auc_d_fake = evaluate(label, D, metric=self.opt.metric)
            auc_d_real = evaluate(label, D_real, metric=self.opt.metric)
            auc_c_fake = evaluate(label, pre, metric=self.opt.metric)
            auc_c_real = evaluate(label, pre_real, metric=self.opt.metric)
            auc_r = evaluate(label, self.relation, metric=self.opt.metric)
            auc_r_img = evaluate(label,
                                 self.relation_img,
                                 metric=self.opt.metric)

            print('Train mul_real ROC AUC Score: %f  mu_fake: %f' %
                  (auc_mu_real, auc_mu_fake))
            print('Train add_real ROC AUC Score: %f  add_fake: %f' %
                  (auc_d_real, auc_d_fake))
            print('Train class_real ROC AUC Score: %f class_fake: %f' %
                  (auc_c_real, auc_c_fake))

            print('Train recon ROC AUC Score: %f recon_img:%f' %
                  (auc_r, auc_r_img))
            print('test done')
Example #4
0
    def test(self):
        """Evaluate using rotation-augmented classifier and image scores.

        Rotates/augments each input batch, scores real and generated images
        with the classifier, accumulates an image-space reconstruction
        distance, and returns ROC-AUC for the three score variants.

        Side effects: writes relation_img scores to ./output/log.txt and
        labels to ./output/label.txt; may save/display test images.

        Returns:
            OrderedDict: 'AUC_R', 'AUC_C_real' and 'AUC_C_fake'.
        """
        with torch.no_grad():

            # Force weight loading and fix epoch bounds for this run.
            self.opt.load_weights = True
            self.epoch1 = 1
            self.epoch2 = 200

            self.total_steps = 0
            epoch_iter = 0
            print('test')
            # Preallocated per-group buffers.
            # NOTE(review): 10000/40000 sizes are hard-coded — assumes a
            # fixed test set of 10000 groups x 4 rotations; confirm.
            label = torch.zeros(size=(10000, ),
                                dtype=torch.long,
                                device=self.device)
            pre = torch.zeros(size=(10000, ),
                              dtype=torch.float32,
                              device=self.device)
            pre_real = torch.zeros(size=(10000, ),
                                   dtype=torch.float32,
                                   device=self.device)

            self.relation = torch.zeros(size=(10000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.distance = torch.zeros(size=(40000, ),
                                        dtype=torch.float32,
                                        device=self.device)
            self.relation_img = torch.zeros(size=(10000, ),
                                            dtype=torch.float32,
                                            device=self.device)
            self.distance_img = torch.zeros(size=(40000, ),
                                            dtype=torch.float32,
                                            device=self.device)
            self.opt.phase = 'test'

            for i, (x, y, z) in enumerate(self.test_loader):
                self.input = Variable(x)
                self.label_r = Variable(z)
                self.total_steps += self.opt.batchsize

                # Builds self.img_real (rotation-augmented batch) in place.
                self.argument_image_rotation_plus(self.input)
                self.label_r = self.label_r.to(self.device)

                # Classifier scores on the real augmented images.
                classfiear_real_1 = self.netc(self.img_real)
                classfiear_real = F.softmax(classfiear_real_1, dim=1)

                # Negative log-probabilities (elementwise, all classes).
                prediction_real = -(torch.log(classfiear_real))
                self.fake = self.netg(self.img_real)
                classfiear_1 = self.netc(self.fake)

                classfiear = F.softmax(classfiear_1, dim=1)

                prediction = -(torch.log(classfiear))
                # One score per group of 4 rotations.
                aaaa = (prediction.size(0) / 4)
                aaaa = int(aaaa)
                # prediction = prediction * (-1/4)

                # NOTE(review): label_z is allocated but never written (the
                # assignment below is commented out) — dead buffer.
                label_z = torch.zeros(size=(aaaa, ),
                                      dtype=torch.long,
                                      device=self.device)
                pre_score = torch.zeros(size=(aaaa, ),
                                        dtype=prediction.dtype,
                                        device=self.device)
                pre_score_real = torch.zeros(size=(aaaa, ),
                                             dtype=prediction.dtype,
                                             device=self.device)

                self.img_trans = self.trans_img(self.fake.cpu())

                # Image-space reconstruction distance per augmented image.
                distance_img = torch.mean(
                    torch.pow((self.fake - self.img_real), 2), -1)
                distance_img = torch.mean(torch.mean(distance_img, -1), -1)

                if self.total_steps % self.opt.save_image_freq == 0:
                    reals, fakes, trans = self.get_current_images()
                    self.visualizer.save_test_images(i, reals, fakes, trans)
                    if self.opt.display:
                        self.visualizer.display_test_images(
                            reals, fakes, trans)

                # distance = torch.mean(torch.pow((classfiear_1 - classfiear_real_1), 2), -1)

                # self.distance[i * self.opt.batchsize: i * self.opt.batchsize + distance.size(0)] = distance.reshape(
                #     distance.size(0))

                # NOTE(review): stride 64 here vs opt.batchsize elsewhere —
                # assumes batchsize == 64; confirm.
                self.distance_img[i * 64:i * 64 +
                                  distance_img.size(0)] = distance_img.reshape(
                                      distance_img.size(0))

                # Per-group score: mean negative log-probability of rotation
                # r under class r.
                for k in range(aaaa):
                    # label_z[k] = self.label_r[k * 4]
                    pre_score[k] = (prediction[k * 4, 0] +
                                    prediction[k * 4 + 1, 1] +
                                    prediction[k * 4 + 2, 2] +
                                    prediction[k * 4 + 3, 3]) / 4
                    pre_score_real[k] = (prediction_real[k * 4, 0] +
                                         prediction_real[k * 4 + 1, 1] +
                                         prediction_real[k * 4 + 2, 2] +
                                         prediction_real[k * 4 + 3, 3]) / 4

                # NOTE(review): the slice has length aaaa (batch/4) but
                # self.label_r appears to span the whole batch — possible
                # size mismatch; verify the loader's label shape.
                label[i * self.opt.batchsize:i * self.opt.batchsize +
                      aaaa] = self.label_r
                pre[i * self.opt.batchsize:i * self.opt.batchsize +
                    aaaa] = pre_score
                pre_real[i * self.opt.batchsize:i * self.opt.batchsize +
                         aaaa] = pre_score_real

            # Keep one image distance per group (first rotation).
            for j in range(10000):
                # self.relation[j] = self.distance[j*4]
                self.relation_img[j] = self.distance_img[j * 4]

            # D = pre + self.relation * 0.2
            # D_real = pre_real + self.relation * 0.2
            #
            # mu = torch.mul(pre, self.relation)
            # mu_real = torch.mul(pre_real, self.relation)
            aaaa = self.relation_img.cpu().numpy()
            np.savetxt('./output/log.txt', aaaa)
            bbbb = label.cpu().numpy()
            np.savetxt('./output/label.txt', bbbb)

            # auc_mu_fake = evaluate(label, mu, metric=self.opt.metric)
            # auc_mu_real = evaluate(label, mu_real, metric=self.opt.metric)
            # auc_d_fake = evaluate(label, D, metric=self.opt.metric)
            # auc_d_real = evaluate(label, D_real, metric=self.opt.metric)
            auc_c_fake = evaluate(label, pre, metric=self.opt.metric)
            auc_c_real = evaluate(label, pre_real, metric=self.opt.metric)
            # auc_r = evaluate(label, self.relation, metric=self.opt.metric)
            auc_r_img = evaluate(label,
                                 self.relation_img,
                                 metric=self.opt.metric)

            performance = OrderedDict([('AUC_R', auc_r_img),
                                       ('AUC_C_real', auc_c_real),
                                       ('AUC_C_fake', auc_c_fake)])
            print('test done')
            return performance
Example #5
0
def train(args):
    """Train, evaluate, or snapshot an SSD300 detector on COCO.

    Args:
        args (dict): docopt-style argument mapping. Keys used:
            "--cuda", "--seed", "--data", "--batch-size", "--backbone",
            "--backbone-path", "--lr", "--wd", "--multistep",
            "--checkpoint", "--mode" ("evaluate" | "create" | anything
            else trains), "--epochs", "--save".

    Side effects:
        Seeds torch/numpy, writes TensorBoard logs to ./runs, optionally
        saves checkpoints under ./models, and exits the process with
        status 1 if the supplied checkpoint path is not a file.
    """
    use_cuda = args["--cuda"]

    # Derive a reproducible seed (random if none was supplied).
    if args["--seed"] is None:
        # Bug fix: pass an int to randint — 1e4 is a float, which newer
        # numpy versions reject.
        args["--seed"] = np.random.randint(10000)
    else:
        args["--seed"] = int(args["--seed"])

    print("Using seed = {}".format(args["--seed"]))
    torch.manual_seed(args["--seed"])
    np.random.seed(seed=args["--seed"])

    # Setup data, defaults.
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    val_coco_gt = get_val_coco_ground_truth(args["--data"])

    train_dataloader = get_train_dataloader(
        args["--data"],
        batch_size=args["--batch-size"]
    )
    val_dataloader = get_val_dataloader(
        args["--data"],
        batch_size=args["--batch-size"]
    )
    ssd = SSD300(backbone=ResNet(args["--backbone"], args["--backbone-path"]))
    loss_func = Loss(dboxes)
    # args.learning_rate * args.N_gpu * (args.batch_size / 32)
    learning_rate = args["--lr"]  # * (args["--batch-size"] / 32)
    start_epoch = 0

    if use_cuda:
        ssd.cuda()
        loss_func.cuda()

    optimizer = torch.optim.Adam(
        tencent_trick(ssd),
        lr=learning_rate,
        weight_decay=args["--wd"]
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=args["--multistep"],
        gamma=0.1
    )

    # Resume from checkpoint: restores model, scheduler and optimizer
    # state, and continues from the recorded epoch.
    if args["--checkpoint"] is not None:
        ch_path = args["--checkpoint"]
        if os.path.isfile(ch_path):
            load_checkpoint(ssd, ch_path)
            checkpoint = torch.load(ch_path)

            start_epoch = checkpoint['epoch']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            sys.exit(1)

    # ---/ Modes
    total_time = 0  # ms

    # Logs
    writer = SummaryWriter(log_dir="./runs")

    # Evaluate-only mode: report validation mAP and stop.
    if args["--mode"] == "evaluate":
        acc = evaluate(
            ssd, val_dataloader, encoder, val_coco_gt, is_cuda=use_cuda
        )
        print('Model precision {} mAP'.format(acc))
        return

    # Create mode: snapshot the freshly-initialized (or resumed) model.
    # Bug fix: the original referenced `epoch`, which is undefined at this
    # point (NameError); use start_epoch instead.
    if args["--mode"] == "create":
        save_model(
            './models/epoch_{}.pt'.format(start_epoch),
            ssd, start_epoch, optimizer, scheduler
        )
        return

    # Train
    for epoch in range(start_epoch, args["--epochs"]):
        start_epoch_time = time.time()
        loss_avg = train_loop(
            ssd, loss_func,
            epoch,
            optimizer,
            train_dataloader, val_dataloader,
            encoder,
            # iteration,
            # logger
            is_cuda=use_cuda
        )
        scheduler.step()
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args["--save"]:
            print("saving model...")
            save_model(
                './models/epoch_{}.pt'.format(epoch),
                ssd, epoch + 1, optimizer, scheduler
            )

        # Calculate validation precision for this epoch.
        acc = evaluate(
            ssd, val_dataloader, encoder, val_coco_gt, is_cuda=use_cuda
        )

        # Log epoch metrics to TensorBoard.
        writer.add_scalar("Loss Avg/train", loss_avg, epoch)
        writer.add_scalar("Accuracy/val [mAP]", acc, epoch)
        writer.add_scalar("Time Epoch, ms", end_epoch_time, epoch)

    print('total training time: {}'.format(total_time))
Example #6
0
    def test(self):
        """ Test GANomaly model.

        Runs one pass over the test dataloader, collecting per-sample
        anomaly scores, labels, latent codes, discriminator predictions and
        discriminator feature maps; optionally spawns visualizer threads
        and builds a latent-space dataloader for a downstream classifier.

        Args:
            dataloader ([type]): Dataloader for the test set

        Raises:
            IOError: Model weights not found.

        Returns:
            OrderedDict: 'Avg Run Time (ms/batch)' and 'AUC' of the
            normalized anomaly scores against the ground-truth labels.
        """

        # Put the generator in eval mode when strengthened testing is on.
        if self.opt.strengthen:
            self.netg.eval()
        with torch.no_grad():

            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(
                    self.name.lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    # NOTE(review): load_state_dict raises RuntimeError (not
                    # IOError) on a mismatch — confirm this handler can fire.
                    raise IOError("netG weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Create big error tensor for the test set.
            # Preallocate per-sample buffers sized to the full test set.
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.latent_i = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.latent_o = torch.zeros(size=(len(
                self.dataloader['test'].dataset), self.opt.nz),
                                        dtype=torch.float32,
                                        device=self.device)
            self.d_pred = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                      dtype=torch.float32,
                                      device=self.device)
            # Feature-map buffer shaped from netd's architecture.
            # NOTE(review): assumes the third-from-last layer of netd's
            # first child is a Conv layer whose out_channels/kernel_size
            # describe the stored feature map — confirm against netd.
            self.last_feature = torch.zeros(
                size=(len(self.dataloader['test'].dataset),
                      list(self.netd.children())[0][-3].out_channels,
                      list(self.netd.children())[0][-3].kernel_size[0],
                      list(self.netd.children())[0][-3].kernel_size[1]),
                dtype=torch.float32,
                device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)
                d_pred, features = self.netd(self.input)

                # Anomaly score: mean squared distance between encoder and
                # re-encoder latent vectors, one scalar per sample.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                # Scatter this batch's results into the preallocated
                # buffers; error.size(0) copes with a short final batch.
                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = error.reshape(error.size(0))
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               error.size(0)] = self.gt.reshape(error.size(0))
                self.latent_i[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_i.reshape(
                                  error.size(0), self.opt.nz)
                self.latent_o[i * self.opt.batchsize:i * self.opt.batchsize +
                              error.size(0), :] = latent_o.reshape(
                                  error.size(0), self.opt.nz)
                self.d_pred[i * self.opt.batchsize:i * self.opt.batchsize +
                            d_pred.size(0)] = d_pred.reshape(d_pred.size(0))
                self.last_feature[
                    i * self.opt.batchsize:i * self.opt.batchsize +
                    error.size(0), :] = features.reshape(
                        error.size(0),
                        list(self.netd.children())[0][-3].out_channels,
                        list(self.netd.children())[0][-3].kernel_size[0],
                        list(self.netd.children())[0][-3].kernel_size[1])

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _, _ = self.get_current_images(
                    )  # point add attribute fixed_real
                    vutils.save_image(real,
                                      '%s/real_%03d.eps' % (dst, i + 1),
                                      normalize=True)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.eps' % (dst, i + 1),
                                      normalize=True)
            """
            data=[]
            feature = self.last_feature.cpu().numpy().reshape(self.last_feature.size()[0], -1)
            label = self.gt_labels.cpu().numpy().reshape(self.last_feature.size()[0], -1)            
            features_dir = './features'
            file_name = 'features_map.csv'
            feature_path = os.path.join(features_dir, file_name + '.txt')
            import pandas as pd
            feature.tolist()
            label.tolist()
            test = pd.DataFrame(data=feature)
            test.to_csv("./feature.csv", mode='a+', index=None, header=None)
            test = pd.DataFrame(data=label)
            test.to_csv("./label.csv", mode='a+', index=None, header=None)
            print('END')
            """

            # Measure inference time.
            # Collapses self.times to the mean of the first 100 batch
            # timings, in milliseconds.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]
            self.an_scores = (self.an_scores - torch.min(self.an_scores)) / (
                torch.max(self.an_scores) - torch.min(self.an_scores))

            # auc, eer = roc(self.gt_labels, self.an_scores)
            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            # Strengthened mode: run visualizations on background threads
            # (score histogram; with strengthen > 1 also t-SNE and LDA).
            if self.opt.strengthen and self.opt.phase == 'test':
                t0 = threading.Thread(
                    target=self.visualizer.display_scores_histo,
                    name='histogram ',
                    args=(self.epoch, self.an_scores, self.gt_labels))
                t0.start()
                if self.opt.strengthen > 1:
                    t1 = threading.Thread(
                        target=self.visualizer.display_feature,
                        name='t-SNE visualizer',
                        args=(self.last_feature, self.gt_labels))
                    t2 = threading.Thread(
                        target=self.visualizer.display_latent,
                        name='latent LDA visualizer',
                        args=(self.latent_i, self.latent_o, self.gt_labels, 9,
                              1000, True))
                    t1.start()
                    t2.start()

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            # Optionally build a latent-space dataset for a downstream
            # classifier over (latent_i, latent_o, labels).
            if self.opt.classifier:
                self.z_dataloader = set_dataset(self.opt, self.latent_i,
                                                self.latent_o, self.gt_labels)
            return performance
Example #7
0
    def z_test(self):
        """ Test the latent-space classifiers of the GANomaly model.

        Runs ``netc_i`` over the ``i_test`` split and ``netc_o`` over the
        ``o_test`` split of ``z_dataloader``, collects per-sample predictions
        and ground-truth labels, combines the two prediction streams with the
        weights ``opt.w_i`` / ``opt.w_o`` and scores the result with the
        ``opt.z_metric`` metric.

        Raises:
            IOError: Model weights not found.

        Returns:
            OrderedDict: average run time (ms/batch) and the metric score.
        """

        # Switch all evaluated networks to inference mode (disables dropout
        # and batch-norm statistic updates).
        self.netd.eval()
        self.netc_i.eval()
        self.netc_o.eval()
        with torch.no_grad():

            # Optionally restore pretrained weights for the discriminator and
            # the two latent classifiers from the training output directory.
            if self.opt.z_load_weights:
                # NOTE(review): sibling implementations call
                # self.name().lower(); here ``name`` is used as an attribute —
                # confirm which convention this class actually follows.
                d_path = "./output/{}/{}/train/weights/netD.pth".format(
                    self.name.lower(), self.opt.dataset)
                i_path = "./output/{}/{}/train/weights/netC_i.pth".format(
                    self.name.lower(), self.opt.dataset)
                o_path = "./output/{}/{}/train/weights/netC_o.pth".format(
                    self.name.lower(), self.opt.dataset)
                d_pretrained_dict = torch.load(d_path)['state_dict']
                i_pretrained_dict = torch.load(i_path)['state_dict']
                o_pretrained_dict = torch.load(o_path)['state_dict']

                try:
                    self.netd.load_state_dict(d_pretrained_dict)
                    self.netc_i.load_state_dict(i_pretrained_dict)
                    self.netc_o.load_state_dict(o_pretrained_dict)
                except IOError:
                    raise IOError("net weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Pre-allocate flat result tensors sized to the full test splits;
            # each loop below fills its batch slice in place.
            self.i_pred = torch.zeros(size=(len(
                self.z_dataloader['i_test'].dataset), ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.o_pred = torch.zeros(size=(len(
                self.z_dataloader['o_test'].dataset), ),
                                      dtype=torch.float32,
                                      device=self.device)
            self.i_gt_labels = torch.zeros(size=(len(
                self.z_dataloader['i_test'].dataset), ),
                                           dtype=torch.long,
                                           device=self.device)
            self.o_gt_labels = torch.zeros(size=(len(
                self.z_dataloader['o_test'].dataset), ),
                                           dtype=torch.long,
                                           device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0

            # Pass 1: classify the encoder-side ("i") latent vectors.
            for i, data in enumerate(self.z_dataloader['i_test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.z_set_input('i', data)
                i_pred = self.netc_i(self.i_input)
                time_o = time.time()

                # Scatter this batch's predictions/labels into the flat slice
                # [i*batchsize, i*batchsize + actual_batch_len).
                self.i_pred[i * self.opt.batchsize:i * self.opt.batchsize +
                            i_pred.size(0)] = i_pred.reshape(i_pred.size(0))
                self.i_gt_labels[i *
                                 self.opt.batchsize:i * self.opt.batchsize +
                                 i_pred.size(0)] = self.i_gt.reshape(
                                     self.i_gt.size(0))

                self.times.append(time_o - time_i)

            # Pass 2: classify the decoder-side ("o") latent vectors.
            for i, data in enumerate(self.z_dataloader['o_test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.z_set_input('o', data)
                o_pred = self.netc_o(self.o_input)

                time_o = time.time()

                self.o_pred[i * self.opt.batchsize:i * self.opt.batchsize +
                            o_pred.size(0)] = o_pred.reshape(o_pred.size(0))
                self.o_gt_labels[i *
                                 self.opt.batchsize:i * self.opt.batchsize +
                                 o_pred.size(0)] = self.o_gt.reshape(
                                     self.o_gt.size(0))

                self.times.append(time_o - time_i)

            # Measure inference time: mean over the first 100 batches, in ms.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Weighted combination of the two classifier outputs.
            self.pred_c = self.i_pred.cpu() * self.opt.w_i + \
                          self.o_pred.cpu() * self.opt.w_o

            # NOTE(review): the combined predictions are scored against the
            # o-branch labels only — this assumes the i_test and o_test splits
            # are label-aligned sample-for-sample; confirm against the loader.
            scores = evaluate(self.o_gt_labels.cpu(), self.pred_c,
                              self.opt.z_metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       (self.opt.z_metric, scores)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.z_dataloader['i_test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
Пример #8
0
    def test(self):
        """ Test the classifier branch of the GANomaly model.

        Feeds every test batch through the encoder (``neten``), decoder
        (``netde``) and classifier (``netc``), accumulates the classifier
        outputs as anomaly scores, and evaluates them against the ground
        truth with ``opt.metric``.

        Raises:
            IOError: Model weights not found.

        Returns:
            OrderedDict: average run time (ms/batch) and the AUC score.
        """
        with torch.no_grad():
            # Optionally restore pretrained classifier weights.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netc.pth".format(
                    self.name().lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netc.load_state_dict(pretrained_dict)

                except IOError:
                    raise IOError("netc weights not found")
                print('   Loaded weights.')
            self.opt.phase = 'test'

            # Pre-allocate flat score/label tensors sized to the full test
            # set; the loop below fills each batch slice in place.
            self.gt_labels = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.long,
                                         device=self.device)
            self.an_scores = torch.zeros(size=(len(
                self.dataloader['test'].dataset), ),
                                         dtype=torch.float32,
                                         device=self.device)
            self.times = []
            self.total_steps = 0
            epoch_iter = 0
            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                # Pipeline: encode -> decode -> classify the reconstruction.
                self.l1 = self.neten(self.input)
                self.del1 = self.netde(self.l1)
                self.out_c = self.netc(self.del1)

                time_o = time.time()
                # NOTE(review): out_c is stored without a reshape — this
                # assumes netc returns a 1-D tensor of per-sample scores;
                # confirm against the classifier's output shape.
                self.an_scores[i * self.opt.batchsize:i * self.opt.batchsize +
                               self.out_c.size(0)] = self.out_c
                self.gt_labels[i * self.opt.batchsize:i * self.opt.batchsize +
                               self.out_c.size(0)] = self.gt.reshape(
                                   self.out_c.size(0))
                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test',
                                       'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real,
                                      '%s/real_%03d.png' % (dst, i + 1),
                                      normalize=True,
                                      nrow=4)
                    vutils.save_image(fake,
                                      '%s/fake_%03d.png' % (dst, i + 1),
                                      normalize=True,
                                      nrow=4)

            # Measure inference time: mean over the first 100 batches, in ms.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            auc = evaluate(self.gt_labels,
                           self.an_scores,
                           metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times),
                                       ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(
                    self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio,
                                                 performance)
            return performance
Пример #9
0
    def test(self):
        """ Test the GANomaly model and derive an anomaly-score threshold.

        Runs the generator over the test set, scores each sample by the mean
        squared distance between its two latent encodings, normalizes the
        scores to [0, 1], searches a 50-bin histogram of normal vs. abnormal
        scores for the cut point ("proline") that best separates the classes,
        and finally evaluates the scores with ``opt.metric``.

        Raises:
            IOError: Model weights not found.

        Returns:
            tuple: ``(performance, dict_info)`` where ``performance`` is an
            OrderedDict with the average run time and AUC, and ``dict_info``
            holds the raw score min/max, the threshold and the AUC for
            serialization.
        """

        with torch.no_grad():
            # Load the weights of netg and netd.
            if self.opt.load_weights:
                path = "./output/{}/{}/train/weights/netG.pth".format(self.name.lower(), self.opt.dataset)
                pretrained_dict = torch.load(path)['state_dict']

                try:
                    self.netg.load_state_dict(pretrained_dict)
                except IOError:
                    raise IOError("netG weights not found")
                print('   Loaded weights.')

            self.opt.phase = 'test'

            # Pre-allocate flat result tensors for the whole test set; the
            # loop below fills each batch slice in place.
            n_test = len(self.dataloader['test'].dataset)
            self.an_scores = torch.zeros(size=(n_test,), dtype=torch.float32, device=self.device)
            self.gt_labels = torch.zeros(size=(n_test,), dtype=torch.long, device=self.device)
            self.latent_i = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)
            self.latent_o = torch.zeros(size=(n_test, self.opt.nz), dtype=torch.float32, device=self.device)

            self.times = []
            self.total_steps = 0
            epoch_iter = 0

            for i, data in enumerate(self.dataloader['test'], 0):
                self.total_steps += self.opt.batchsize
                epoch_iter += self.opt.batchsize
                time_i = time.time()
                self.set_input(data)
                self.fake, latent_i, latent_o = self.netg(self.input)

                # Anomaly score: per-sample MSE between the encoder latent and
                # the re-encoded (decoder-side) latent.
                error = torch.mean(torch.pow((latent_i - latent_o), 2), dim=1)
                time_o = time.time()

                lo = i * self.opt.batchsize
                hi = lo + error.size(0)
                self.an_scores[lo:hi] = error.reshape(error.size(0))
                self.gt_labels[lo:hi] = self.gt.reshape(error.size(0))
                self.latent_i[lo:hi, :] = latent_i.reshape(error.size(0), self.opt.nz)
                self.latent_o[lo:hi, :] = latent_o.reshape(error.size(0), self.opt.nz)

                self.times.append(time_o - time_i)

                # Save test images.
                if self.opt.save_test_images:
                    dst = os.path.join(self.opt.outf, self.opt.name, 'test', 'images')
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    real, fake, _ = self.get_current_images()
                    vutils.save_image(real, '%s/real_%03d.eps' % (dst, i + 1), normalize=True)
                    vutils.save_image(fake, '%s/fake_%03d.eps' % (dst, i + 1), normalize=True)

            # Measure inference time: mean over the first 100 batches, in ms.
            self.times = np.array(self.times)
            self.times = np.mean(self.times[:100] * 1000)

            # Scale error vector between [0, 1]; keep the raw extrema so they
            # can be reported in dict_info (min/max computed once and reused).
            minNUM = torch.min(self.an_scores)
            maxNUM = torch.max(self.an_scores)
            print(minNUM)
            print(maxNUM)
            self.an_scores = (self.an_scores - minNUM) / (maxNUM - minNUM)

            # -------------- Threshold search ------------------
            print('-------------- 处理阈值 ------------------')
            print(len(self.gt_labels))
            plt.ion()
            scores = {}
            # Create data frame for scores and labels.
            scores['scores'] = self.an_scores
            scores['labels'] = self.gt_labels
            hist = pd.DataFrame.from_dict(scores)

            # Filter normal and abnormal scores.
            abn_scr = hist.loc[hist.labels == 1]['scores']
            nrm_scr = hist.loc[hist.labels == 0]['scores']

            b = list(nrm_scr)  # normal scores
            c = list(abn_scr)  # abnormal scores
            print('asasddda')
            print(len(nrm_scr))
            print(len(abn_scr))

            # Histogram both classes into 50 bins of width 0.02 over [0, 1).
            # Fix: np.int was deprecated in NumPy 1.20 and removed in 1.24 —
            # use the builtin int as the dtype instead.
            nrm = np.zeros((50), dtype=int)
            minfix = 0.4  # fallback threshold if the search never triggers
            abn = np.zeros((50), dtype=int)
            abmin = 30
            for k in np.arange(0, 1, 0.02):
                kint = int(k * 50)
                for j in range(len(nrm_scr)):
                    if b[j] >= k and b[j] < (k + 0.02):
                        nrm[kint] = nrm[kint] + 1
                for j in range(len(abn_scr)):
                    if c[j] >= k and c[j] < (k + 0.02):
                        abn[kint] = abn[kint] + 1
            print(nrm)
            print(abn)

            # Pick the bin boundary k maximizing (#normals below k) +
            # (#abnormals at or above k), i.e. the best separating cut.
            # NOTE(review): len(nrm) + len(abn) is always 100 (these are the
            # bin arrays, not the sample lists) — confirm the initial floor of
            # 25 is intended rather than a fraction of the sample count.
            max_dist = (len(nrm) + len(abn)) * 0.25
            for k in range(0, 50):
                num1 = np.sum(nrm[0:k])
                num2 = np.sum(abn[k::])
                if (num1 + num2) >= max_dist:
                    minfix = round((k / 50) + 0.05, 3)
                    max_dist = num1 + num2

            proline = minfix
            print(proline)

            print(self.gt_labels[0:20])
            print(self.an_scores[0:20])

            print('-------------  处理阈值 END --------------')
            # -------------  Threshold search END --------------

            auc = evaluate(self.gt_labels, self.an_scores, metric=self.opt.metric)
            performance = OrderedDict([('Avg Run Time (ms/batch)', self.times), ('AUC', auc)])

            if self.opt.display_id > 0 and self.opt.phase == 'test':
                counter_ratio = float(epoch_iter) / len(self.dataloader['test'].dataset)
                self.visualizer.plot_performance(self.epoch, counter_ratio, performance)

            # Serialize the key statistics for downstream consumers.
            dict_info = {}
            dict_info['minVal'] = float(minNUM.item())
            dict_info['maxVal'] = float(maxNUM.item())
            dict_info['proline'] = float(proline)
            dict_info['auc'] = float(auc)
            dict_info['Avg Run Time (ms/batch)'] = float(self.times)

            return performance, dict_info
Пример #10
0
    def test(self):
        """Evaluate the model on the test split.

        Accumulates ground-truth masks, post-processed predictions and
        per-batch losses, computes ROC / PR / F1 over the flattened masks,
        checkpoints the weights whenever a best score improves, and records
        the summary metrics in ``errors_dict`` / ``score_dict``.
        """
        self.model.eval()

        loss_history = []
        pred_frames = []
        gt_frames = []

        with torch.no_grad():

            progress = tqdm(self.dataloader['test'],
                            leave=True,
                            ncols=100,
                            total=len(self.dataloader['test']))
            for batch in progress:

                clip, ref, mask, _ = (t.to('cuda') for t in batch)
                raw_pred = self.model(clip)
                binarized = threshold(raw_pred)
                cleaned = morphology_proc(binarized)

                # Collect masks/predictions as channel-last numpy arrays.
                gt_frames.append(mask.permute(0, 2, 3, 4, 1).cpu().numpy())
                pred_frames.append(cleaned.permute(0, 2, 3, 4, 1).cpu().numpy())

                loss_history.append(self.loss(raw_pred, mask).item())

                self.color_video_dict.update({
                    'test/input-real':
                    torch.cat([clip, ref], dim=3),
                })
                self.gray_video_dict.update({
                    'test/mask-pre-th-mor':
                    torch.cat([mask, raw_pred, binarized, cleaned], dim=3)
                })

                progress.set_description("[TEST Epoch %d/%d]" %
                                         (self.epoch, self.args.ep))

            gt_flat = np.asarray(np.stack(gt_frames), dtype=np.int32).flatten()
            pred_flat = np.asarray(np.stack(pred_frames)).flatten()

            roc = evaluate(gt_flat,
                           pred_flat,
                           self.best_roc,
                           self.epoch,
                           self.save_root_dir,
                           metric='roc')
            pr = evaluate(gt_flat,
                          pred_flat,
                          self.best_pr,
                          self.epoch,
                          self.save_root_dir,
                          metric='pr')
            f1 = evaluate(gt_flat, pred_flat, metric='f1_score')

            # Checkpoint on an improved ROC first; otherwise on improved PR.
            if roc > self.best_roc:
                self.best_roc = roc
                self.save_weights('ROC', self.best_roc)
            elif pr > self.best_pr:
                self.best_pr = pr
                self.save_weights('PR', self.best_pr)

            self.errors_dict.update({'loss/err/test': np.mean(loss_history)})
            self.score_dict.update({
                "score/roc": roc,
                "score/pr": pr,
                "score/f1": f1,
            })
Пример #11
0
    def test(self):
        """Evaluate generator and discriminator on the test split.

        For every test batch the generator reconstructs the clip, the
        prediction is thresholded and morphologically cleaned, and both the
        spatial and temporal discriminator branches are scored on real vs.
        generated optical-flow inputs.  All partial losses are accumulated,
        ROC / PR / F1 are computed over the flattened masks, weights are
        checkpointed whenever a best score improves, and the summaries are
        written to ``score_dict`` / ``errors_dict``.
        """
        # Per-batch generator loss histories (spatial/temporal adversarial,
        # content, and weighted total).  err_g_pre_ is never filled here.
        err_g_adv_s_ = []
        err_g_adv_t_ = []
        err_g_adv_ = []
        err_g_con_ = []
        err_g_pre_ = []
        err_g_ = []

        # Per-batch discriminator loss histories (spatial/temporal branches
        # for real and fake inputs, plus their averages).
        err_d_real_s_ = []
        err_d_real_t_ = []
        err_d_fake_s_ = []
        err_d_fake_t_ = []
        err_d_real_ = []
        err_d_fake_ = []
        err_d_ = []

        predicts = []
        gts = []

        with torch.no_grad():
            # Test
            pbar = tqdm(self.dataloader['test'],
                        leave=True,
                        ncols=100,
                        total=len(self.dataloader['test']))
            for i, data in enumerate(pbar):
                # set test data
                input, real, gt, lb = (d.to('cuda') for d in data)
                # NetG: reconstruct the input clip, then binarize and
                # morphologically clean the prediction.
                predict_ = self.netg(input)
                t_pre_ = threshold(predict_.detach())
                m_pre_ = morphology_proc(t_pre_)
                gts.append(gt.permute(0, 2, 3, 4, 1).cpu().numpy())
                predicts.append(m_pre_.permute(0, 2, 3, 4, 1).cpu().numpy())
                # NetD
                # Calc optical flow on 3-channel versions of gt/prediction.
                gt_3ch_ = gray2rgb(gt)
                pre_3ch_ = gray2rgb(predict_)
                gt_flow_ = video_to_flow(gt_3ch_.detach()).to('cuda')
                pre_flow_ = video_to_flow(pre_3ch_).to('cuda')
                # Discriminator outputs: spatial (s_) and temporal (t_)
                # predictions plus their feature maps, for real and fake.
                s_pred_real_, s_feat_real_, t_pred_real_, t_feat_real_ \
                                        = self.netd(gt_3ch_, gt_flow_.detach())
                s_pred_fake_, s_feat_fake_, t_pred_fake_, t_feat_fake_ \
                                        = self.netd(pre_3ch_.detach(), pre_flow_.detach())
                # Calc err_g: feature-matching adversarial terms per branch.
                err_g_adv_s_.append(
                    self.l_adv(s_feat_real_, s_feat_fake_).item())
                err_g_adv_t_.append(
                    self.l_adv(t_feat_real_, t_feat_fake_).item())
                err_g_adv_.append(err_g_adv_s_[-1] + err_g_adv_t_[-1])
                err_g_con_.append(self.l_con(predict_, gt).item())
                # NOTE(review): the weighted total uses only the temporal
                # adversarial term (err_g_adv_t_), not the combined
                # err_g_adv_ — confirm this matches the training objective.
                err_g_.append(err_g_adv_t_[-1] * self.args.w_adv +
                              err_g_con_[-1] * self.args.w_con)
                # Calc err_d: BCE against the real/fake target labels.
                err_d_real_s_.append(
                    self.l_bce(s_pred_real_, self.real_label).item())
                err_d_real_t_.append(
                    self.l_bce(t_pred_real_, self.real_label).item())
                err_d_fake_s_.append(
                    self.l_bce(s_pred_fake_, self.gout_label).item())
                err_d_fake_t_.append(
                    self.l_bce(t_pred_fake_, self.gout_label).item())
                err_d_real_.append(
                    (err_d_real_s_[-1] + err_d_real_t_[-1]) * 0.5)
                err_d_fake_.append(
                    (err_d_fake_s_[-1] + err_d_fake_t_[-1]) * 0.5)
                err_d_.append((err_d_real_[-1] + err_d_fake_[-1]) * 0.5)

                # test video summary
                self.color_video_dict.update({
                    'test/input-real':
                    torch.cat([input, real], dim=3),
                })
                self.gray_video_dict.update({
                    'test/gt-pre-th-morph':
                    torch.cat([gt, predict_, t_pre_, m_pre_], dim=3)
                })
                self.hist_dict.update({
                    "test/inp": input,
                    "test/gt": gt,
                    "test/predict": predict_,
                    "test/t_pre": t_pre_,
                    "test/m_pre": m_pre_
                })

                pbar.set_description("[TEST  Epoch %d/%d]" %
                                     (self.epoch + 1, self.args.ep))

            # AUC: flatten all collected masks/predictions and score them.
            gts = np.asarray(np.stack(gts), dtype=np.int32).flatten()
            predicts = np.asarray(np.stack(predicts)).flatten()
            roc = evaluate(gts,
                           predicts,
                           self.best_roc,
                           self.epoch,
                           self.save_root_dir,
                           metric='roc')
            pr = evaluate(gts,
                          predicts,
                          self.best_pr,
                          self.epoch,
                          self.save_root_dir,
                          metric='pr')
            f1 = evaluate(gts, predicts, metric='f1_score')
            # Checkpoint on an improved ROC first; otherwise on improved PR.
            if roc > self.best_roc:
                self.best_roc = roc
                self.save_weights('roc')
            elif pr > self.best_pr:
                self.best_pr = pr
                self.save_weights('pr')

            # Update summary of loss and auc.
            self.score_dict.update({
                "score/roc": roc,
                "score/pr": pr,
                "score/f1": f1
            })
            self.errors_dict.update({
                'd/err_d_real_s/test':
                np.mean(err_d_real_s_),
                'd/err_d_real_t/test':
                np.mean(err_d_real_t_),
                'd/err_d_fake_s/test':
                np.mean(err_d_fake_s_),
                'd/err_d_fake_t/test':
                np.mean(err_d_fake_t_),
                'd/err_d_real/test':
                np.mean(err_d_real_),
                'd/err_d_fake/test':
                np.mean(err_d_fake_),
                'd/err_d/test':
                np.mean(err_d_),
                'g/err_g_adv_s/test':
                np.mean(err_g_adv_s_),
                'g/err_g_adv_t/test':
                np.mean(err_g_adv_t_),
                'g/err_g_adv/test':
                np.mean(err_g_adv_),
                'g/err_g_con/test':
                np.mean(err_g_con_),
                'g/err_g/test':
                np.mean(err_g_)
            })
Пример #12
0
    def test(self):
        """Evaluate the GAN on the test split.

        Computes the discriminator's real/fake BCE losses and the generator's
        adversarial loss over the test set, derives foreground predictions
        from the generated frames, scores them (ROC / PR / F1), checkpoints
        improved weights and records the summary statistics in
        ``score_dict`` / ``errors_dict``.
        """
        self.netg.eval()
        self.netd.eval()

        gen_loss_ = []
        dis_loss_real_ = []
        dis_loss_fake_ = []
        dis_loss_ = []

        gts = []
        predicts = []

        with torch.no_grad():
            # Test
            pbar = tqdm(self.dataloader['test'], leave=True, ncols=100, total=len(self.dataloader['test']))
            for data in pbar:

                # Set test data; the label tensor is unused here.
                # (local renamed from `input` to avoid shadowing the builtin)
                input_, real, gt, _lb = (d.to('cuda') for d in data)

                # NetD: real branch.
                dis_real_ = self.netd(real)[0].view(-1)
                dis_loss_real_.append(self.loss(dis_real_, self.ones_label).item())

                # NetD: fake branch, generated from a random latent vector.
                z = torch.randn(self.args.batchsize, 100, device='cuda')
                gen_fake_ = self.netg(z)
                dis_fake_ = self.netd(gen_fake_.detach())[0].view(-1)
                dis_loss_fake_.append(self.loss(dis_fake_, self.zeros_label).item())
                dis_loss_.append(dis_loss_real_[-1] + dis_loss_fake_[-1])

                # NetG: generator wants its fakes judged real.
                dis_fake_ = self.netd(gen_fake_)[0].view(-1)
                gen_loss_.append(self.loss(dis_fake_, self.ones_label).item())

                # Foreground prediction from the generated frames, then
                # thresholded and morphologically cleaned for the summaries.
                predict_ = predict_forg(gen_fake_, real)
                t_pre_ = threshold(predict_.detach())
                m_pre_ = morphology_proc(t_pre_)

                gts.append(gt.permute(0, 2, 3, 4, 1).cpu().numpy())
                predicts.append(predict_.permute(0, 2, 3, 4, 1).cpu().numpy())

                # test video summary
                self.color_video_dict.update({
                        'test/input-real-gen': torch.cat([input_, real, gen_fake_], dim=3),
                    })
                self.gray_video_dict.update({
                        'test/gt-pre-th-morph': torch.cat([gt, predict_, t_pre_, m_pre_], dim=3)
                    })
                self.hist_dict.update({
                    "test/inp": input_,
                    "test/gt": gt,
                    "test/gen": gen_fake_,
                    "test/predict": predict_,
                    "test/t_pre": t_pre_,
                    "test/m_pre": m_pre_
                    })

                pbar.set_description("[TEST  Epoch %d/%d]" % (self.epoch+1, self.args.ep))

            # AUC: flatten all collected masks/predictions and score them.
            gts = np.asarray(np.stack(gts), dtype=np.int32).flatten()
            predicts = np.asarray(np.stack(predicts)).flatten()
            roc = evaluate(gts, predicts, self.best_roc, self.epoch, self.save_root_dir, metric='roc')
            pr = evaluate(gts, predicts, self.best_pr, self.epoch, self.save_root_dir, metric='pr')
            f1 = evaluate(gts, predicts, metric='f1_score')
            # Checkpoint on an improved ROC first; otherwise on improved PR.
            if roc > self.best_roc:
                self.best_roc = roc
                self.save_weights('roc')
            elif pr > self.best_pr:
                self.best_pr = pr
                self.save_weights('pr')

            # Update summary of loss and auc.
            self.score_dict.update({
                    "score/roc": roc,
                    "score/pr": pr,
                    "score/f1": f1
                })
            # Bug fix: the two means were swapped — the discriminator key was
            # logging the generator loss and vice versa.  Key strings are kept
            # unchanged for dashboard compatibility.
            # NOTE(review): 'd/err_g' likely ought to be 'g/err_g' (cf. the
            # sibling trainer), but renaming would break existing log readers.
            self.errors_dict.update({
                        'd/err_d/test': np.mean(dis_loss_),
                        'd/err_g/test': np.mean(gen_loss_)
                        })