def __init__(self, args):
        """Prepare the validation loader, segmentation model and metric.

        Args:
            args: parsed options; reads ``device``, ``dataset``,
                ``distributed``, ``workers``, ``model``, ``backbone``,
                ``aux`` and ``local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Tensor conversion followed by ImageNet mean/std normalisation.
        mean, std = [.485, .456, .406], [.229, .224, .225]
        input_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)])

        # Validation split in 'testval' mode, one image per batch.
        val_dataset = get_segmentation_dataset(
            args.dataset, split='val', mode='testval',
            transform=input_transform)
        val_batch_sampler = make_batch_data_sampler(
            make_data_sampler(val_dataset, False, args.distributed),
            images_per_batch=1)
        self.val_loader = data.DataLoader(
            dataset=val_dataset, batch_sampler=val_batch_sampler,
            num_workers=args.workers, pin_memory=True)

        # SyncBatchNorm when running distributed, plain BatchNorm2d otherwise.
        if args.distributed:
            norm_layer = nn.SyncBatchNorm
        else:
            norm_layer = nn.BatchNorm2d
        net = get_segmentation_model(
            model=args.model, dataset=args.dataset, backbone=args.backbone,
            aux=args.aux, pretrained=True, pretrained_base=False,
            local_rank=args.local_rank, norm_layer=norm_layer)
        net = net.to(self.device)
        if args.distributed:
            net = nn.parallel.DistributedDataParallel(
                net, device_ids=[args.local_rank],
                output_device=args.local_rank)
        self.model = net
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)
コード例 #2
0
    def __init__(self, args):
        """Set up the validation loader, segmentation model and metric.

        Args:
            args: parsed options; reads ``device``, ``dataset``,
                ``distributed``, ``workers``, ``model`` and ``backbone``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: validation split, one image per batch
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # This method never wraps the model in DistributedDataParallel, so
            # a plain nn.Module has no ``.module`` and direct access would
            # raise AttributeError. Unwrap only if a wrapper is present.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)
コード例 #3
0
    def __init__(self, config):
        """Build the test loader, segmentation model and metric from a config.

        Args:
            config: mapping with ``run_config``, ``optim_config``,
                ``data_config`` and ``model_config`` sub-mappings.
        """
        self.config = config
        self.run_config = config['run_config']
        self.optim_config = config['optim_config']
        self.data_config = config['data_config']
        self.model_config = config['model_config']

        self.device = torch.device(self.run_config["device"])

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: test split, 10 images per batch, keeping the
        # final partial batch so no image is dropped
        val_dataset = get_segmentation_dataset(
            self.data_config['dataset_name'],
            root=self.data_config['dataset_root'],
            split='test',
            mode='test',
            transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False,
                                        self.run_config['distributed'])
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=10,
                                                    drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=4,
                                          pin_memory=True)

        # create network. Run settings are always read through
        # ``self.run_config``; a bare ``run_config`` would only resolve if the
        # calling module happened to define a global of that name.
        BatchNorm2d = nn.SyncBatchNorm if self.run_config[
            'distributed'] else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=self.model_config['model'],
            dataset=self.data_config['dataset_name'],
            backbone=self.model_config['backbone'],
            aux=self.optim_config['aux'],
            jpu=self.model_config['jpu'],
            norm_layer=BatchNorm2d,
            root=self.run_config['path']['eval_model_root'],
            pretrained=self.run_config['eval_model'],
            pretrained_base=False,
            local_rank=self.run_config['local_rank']).to(self.device)

        if self.run_config['distributed']:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.run_config['local_rank']],
                output_device=self.run_config['local_rank'])
        elif len(self.run_config['gpu_ids']) > 1:
            # single-process multi-GPU
            assert torch.cuda.is_available()
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)
コード例 #4
0
    def __init__(self, args):
        """Set up loaders, model, loss, optimizer and LR schedule for training.

        Args:
            args: parsed options; reads ``base_size``, ``crop_size``,
                ``dataset``, ``train_split``, ``batch_size``, ``model``,
                ``backbone``, ``aux``, ``aux_weight``, ``device``, ``resume``,
                ``lr``, ``momentum``, ``weight_decay`` and ``epochs``.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth`` (only when
                assertions are enabled).
        """
        self.args = args

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}
        train_dataset = get_segmentation_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            drop_last=True,
                                            shuffle=True)

        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          drop_last=False,
                                          shuffle=False)

        # create network
        self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                            aux=args.aux, norm_layer=nn.BatchNorm2d).to(args.device)

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux, args.aux_weight, ignore_label=-1).to(args.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both extensions explicitly; a bare
                # ``ext == '.pkl' or '.pth'`` is always truthy.
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                # map_location keeps tensors on CPU regardless of the device
                # they were saved from
                self.model.load_state_dict(torch.load(args.resume, map_location=lambda storage, loc: storage))
            else:
                # Make a mistyped resume path visible instead of ignoring it.
                print('Resume checkpoint {} not found, skipping.'.format(args.resume))

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling: polynomial decay (power 0.9) over all training iterations
        self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader), power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0
コード例 #5
0
def eval(args):
    """Evaluate a pretrained segmentation model on the validation split.

    Prints the running pixel accuracy and mIoU (as percentages) after every
    sample. When ``args.save_result`` is set, it also writes a colourised
    prediction mask per sample to ``test_result/seg_<index>.png``.

    Args:
        args: parsed options; reads ``dataset``, ``model``, ``backbone``,
            ``aux`` and ``save_result``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    outdir = 'test_result'
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # Tensor conversion followed by ImageNet mean/std normalisation.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([.485, .456, .406], [.229, .224, .225])
    input_transform = transforms.Compose([to_tensor, normalize])

    test_dataset = get_segmentation_dataset(
        args.dataset, split='val', mode='testval', transform=input_transform)
    test_loader = data.DataLoader(
        dataset=test_dataset, batch_size=1, shuffle=False)

    model = get_segmentation_model(
        model=args.model, dataset=args.dataset, backbone=args.backbone,
        aux=args.aux, pretrained=True, pretrained_base=False).to(device)
    print('Finished loading model!')

    metric = SegmentationMetric(test_dataset.num_class)

    model.eval()
    for idx, (image, label) in enumerate(test_loader):
        with torch.no_grad():
            outputs = model(image.to(device))

            # class index per pixel, taken from the first model output
            pred = torch.argmax(outputs[0], 1).cpu().data.numpy()
            metric.update(pred, label.numpy())

            pix_acc, miou = metric.get()
            print('Sample %d, validation pixAcc: %.3f%%, mIoU: %.3f%%' %
                  (idx + 1, pix_acc * 100, miou * 100))

            if args.save_result:
                mask = get_color_pallete(pred.squeeze(0), args.dataset)
                mask.save(os.path.join(outdir, 'seg_{}.png'.format(idx)))
コード例 #6
0
    def __init__(self, args):
        """Set up the validation loader, model (optionally from a checkpoint) and metric.

        Args:
            args: parsed options; reads ``device``, ``base_size``,
                ``crop_size``, ``dataset``, ``distributed``, ``workers``,
                ``model``, ``backbone``, ``aux``, ``resume``, ``mutilgpu``
                and ``gpu_ids``.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth`` (only when
                assertions are enabled).
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: validation split, one image per batch
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               **data_kwargs)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network (already moved to self.device here)
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both extensions explicitly; a bare
                # ``ext == '.pkl' or '.pth'`` is always truthy.
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))
            else:
                # Make a mistyped checkpoint path visible instead of ignoring it.
                print('Resume checkpoint {} not found, skipping.'.format(args.resume))

        # Single-process multi-GPU. NOTE(review): ``mutilgpu`` looks like a
        # misspelling of "multigpu", but it must match the argument parser.
        if args.mutilgpu:
            self.model = nn.DataParallel(self.model, device_ids=args.gpu_ids)

        self.metric = SegmentationMetric(val_dataset.num_class)
コード例 #7
0
class Evaluator(object):
    """Evaluate a pretrained segmentation model on the validation split.

    Logs running pixel accuracy, mIoU and per-sample FPS. Optionally saves
    colourised prediction masks.
    """

    def __init__(self, args):
        """Build the validation loader, model and metric.

        Args:
            args: parsed options; reads ``device``, ``base_size``,
                ``crop_size``, ``dataset``, ``distributed``, ``workers``,
                ``model``, ``backbone``, ``aux`` and ``local_rank``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: validation split, one image per batch
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               **data_kwargs)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run validation over ``self.val_loader``, logging metrics and FPS.

        When ``self.args.save_pred`` is set, each prediction is saved as a
        PNG named after its input file inside ``outdir``. This class does not
        define ``outdir``; it must be a module-level name set by the calling
        script (NOTE(review): confirm).
        """
        self.metric.reset()
        self.model.eval()
        # The model is wrapped in DistributedDataParallel only when
        # distributed; run the underlying network directly.
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        num_batches = len(self.val_loader)
        logger.info("Start validation, Total sample: {:d}".format(
            num_batches))
        fps_sum = 0.0
        for i, (image, target, filename) in enumerate(self.val_loader):
            # perf_counter is monotonic and high-resolution, unlike time.time,
            # so short per-sample intervals are neither zero nor skewed by
            # clock adjustments. The interval covers transfer, forward pass
            # and metric update.
            start = time.perf_counter()
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            end = time.perf_counter()
            fps = 1.0 / (end - start)
            fps_sum = fps_sum + fps
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}, FPS: {:.3f}"
                .format(i + 1, pixAcc * 100, mIoU * 100, fps))

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()

                predict = pred.squeeze(0)
                mask = get_color_pallete(predict, self.args.dataset)
                mask.save(
                    os.path.join(outdir,
                                 os.path.splitext(filename[0])[0] + '.png'))
                # NOTE: for DANet (insufficient GPU memory) evaluation was
                # previously cut short with ``if i + 1 > 302: break``.
        # Guard against an empty loader instead of dividing by zero.
        avg_fps = fps_sum / num_batches if num_batches else 0.0
        logger.info("avgFPS: {:.3f}".format(avg_fps))
        synchronize()
コード例 #8
0
class Evaluator(object):
    """Evaluate a pretrained segmentation model on the validation split."""

    def __init__(self, args):
        """Build the validation loader, model and metric.

        Args:
            args: parsed options; reads ``device``, ``dataset``,
                ``distributed``, ``workers``, ``model`` and ``backbone``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: validation split, one image per batch
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # The model is never wrapped in DistributedDataParallel here, so a
            # plain nn.Module has no ``.module`` and direct access would raise
            # AttributeError. Unwrap only if a wrapper is present.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run validation over ``self.val_loader`` and log running metrics.

        When ``self.args.save_pred`` is set, each prediction is saved as a
        PNG named after its input file inside ``outdir``. This class does not
        define ``outdir``; it must be a module-level name set by the calling
        script (NOTE(review): confirm).
        """
        self.metric.reset()
        self.model.eval()
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = self.model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()

                predict = pred.squeeze(0)
                # Use the instance's args rather than a script-level global.
                mask = get_color_pallete(predict, self.args.dataset)
                mask.save(
                    os.path.join(outdir,
                                 os.path.splitext(filename[0])[0] + '.png'))
        synchronize()
コード例 #9
0
class Evaluator(object):
    """Run a segmentation model on the test split and write an RLE submission."""

    def __init__(self, config):
        """Build the test loader, segmentation model and metric from a config.

        Args:
            config: mapping with ``run_config``, ``optim_config``,
                ``data_config`` and ``model_config`` sub-mappings.
        """
        self.config = config
        self.run_config = config['run_config']
        self.optim_config = config['optim_config']
        self.data_config = config['data_config']
        self.model_config = config['model_config']

        self.device = torch.device(self.run_config["device"])

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: test split, 10 images per batch, keeping the
        # final partial batch so no image is dropped
        val_dataset = get_segmentation_dataset(
            self.data_config['dataset_name'],
            root=self.data_config['dataset_root'],
            split='test',
            mode='test',
            transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False,
                                        self.run_config['distributed'])
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=10,
                                                    drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=4,
                                          pin_memory=True)

        # create network. Run settings are always read through
        # ``self.run_config``; a bare ``run_config`` would only resolve if the
        # calling module happened to define a global of that name.
        BatchNorm2d = nn.SyncBatchNorm if self.run_config[
            'distributed'] else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=self.model_config['model'],
            dataset=self.data_config['dataset_name'],
            backbone=self.model_config['backbone'],
            aux=self.optim_config['aux'],
            jpu=self.model_config['jpu'],
            norm_layer=BatchNorm2d,
            root=self.run_config['path']['eval_model_root'],
            pretrained=self.run_config['eval_model'],
            pretrained_base=False,
            local_rank=self.run_config['local_rank']).to(self.device)

        if self.run_config['distributed']:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.run_config['local_rank']],
                output_device=self.run_config['local_rank'])
        elif len(self.run_config['gpu_ids']) > 1:
            # single-process multi-GPU
            assert torch.cuda.is_available()
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Predict masks for the test split and write ``submission.csv``.

        Predictions are collected only when ``run_config['save_pred']`` is
        set; otherwise the CSV has no rows. Each collected mask is encoded
        with ``mask2rle`` in an 8-process pool, and empty encodings are
        written as ``'-1'``.

        NOTE(review): under distributed runs every rank writes
        ``submission.csv``, presumably with only its own share of the data;
        confirm this is intended.
        """
        self.metric.reset()
        self.model.eval()
        rles = []
        images = []
        filenames = []
        # Unwrap DistributedDataParallel to run the underlying network.
        if self.run_config['distributed']:
            model = self.model.module
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, filename) in enumerate(self.val_loader):
            print(i)
            image = image.to(self.device)

            with torch.no_grad():
                outputs = model(image)

            if self.run_config['save_pred']:
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()

                # one predicted mask and one file name per batch element
                for predict, f_name in zip(pred, filename):
                    images.append(predict)
                    # drop the ".png" suffix to obtain the image id
                    filenames.append(f_name.split('.pn')[0])
        synchronize()

        # Create the pool before ``try`` so ``finally`` never references an
        # unbound name if pool creation itself fails.
        pool = Pool(8)
        try:
            # imap yields results in input order as workers finish them, so
            # the progress bar tracks real work; ``total`` is the number of
            # masks being encoded.
            for rle in tqdm(pool.imap(mask2rle, images), total=len(images)):
                rles.append(rle)
        finally:  # make sure worker processes are shut down even on errors
            pool.close()
            pool.join()

        sub_df = pd.DataFrame({'ImageId': filenames, 'EncodedPixels': rles})
        sub_df.loc[sub_df.EncodedPixels == '', 'EncodedPixels'] = '-1'
        sub_df.to_csv('submission.csv', index=False)
コード例 #10
0
class Trainer(object):
    """Train a segmentation model with SGD and a poly LR schedule."""

    def __init__(self, args):
        """Set up loaders, model, loss, optimizer and LR schedule.

        Args:
            args: parsed options; reads ``base_size``, ``crop_size``,
                ``dataset``, ``train_split``, ``batch_size``, ``model``,
                ``backbone``, ``aux``, ``aux_weight``, ``device``, ``resume``,
                ``lr``, ``momentum``, ``weight_decay`` and ``epochs``.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth`` (only when
                assertions are enabled).
        """
        self.args = args

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}
        train_dataset = get_segmentation_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            drop_last=True,
                                            shuffle=True)

        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          drop_last=False,
                                          shuffle=False)

        # create network
        self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                            aux=args.aux, norm_layer=nn.BatchNorm2d).to(args.device)

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux, args.aux_weight, ignore_label=-1).to(args.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both extensions explicitly; a bare
                # ``ext == '.pkl' or '.pth'`` is always truthy.
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(torch.load(args.resume, map_location=lambda storage, loc: storage))
            else:
                # Make a mistyped resume path visible instead of ignoring it.
                print('Resume checkpoint {} not found, skipping.'.format(args.resume))

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling: polynomial decay (power 0.9) over all training iterations
        self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader), power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0

    def train(self):
        """Train from ``args.start_epoch`` to ``args.epochs``.

        A checkpoint is saved after every epoch and once more at the end.
        Validation runs after each epoch unless ``args.no_val`` is set.
        """
        cur_iters = 0
        start_time = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.model.train()

            for i, (images, targets) in enumerate(self.train_loader):
                # the LR is set manually on every iteration from the poly schedule
                cur_lr = self.lr_scheduler(cur_iters)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = cur_lr

                images = images.to(self.args.device)
                targets = targets.to(self.args.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                cur_iters += 1
                if cur_iters % 10 == 0:
                    print('Epoch: [%2d/%2d] Iter [%4d/%4d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f' % (
                        epoch, self.args.epochs, i + 1, len(self.train_loader),
                        time.time() - start_time, cur_lr, loss.item()))

            # save every epoch
            save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.no_val:
                self.validation(epoch)

        save_checkpoint(self.model, self.args, is_best=False)

    def validation(self, epoch):
        """Evaluate on the validation set and save a checkpoint.

        The checkpoint is flagged as best when the mean of pixel accuracy and
        mIoU exceeds ``self.best_pred``.

        Args:
            epoch: current epoch index, used only for logging.
        """
        is_best = False
        self.metric.reset()
        self.model.eval()
        # No gradients are needed for evaluation; without no_grad every
        # forward pass would build and keep an autograd graph.
        with torch.no_grad():
            for i, (image, target) in enumerate(self.val_loader):
                image = image.to(self.args.device)

                outputs = self.model(image)
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()

                self.metric.update(pred, target.numpy())
                pixAcc, mIoU = self.metric.get()
                print('Epoch %d, Sample %d, validation pixAcc: %.3f, mIoU: %.3f' % (epoch, i + 1, pixAcc, mIoU))

        # Read the final scores from the metric itself so they are defined
        # even when the loop above did not run.
        pixAcc, mIoU = self.metric.get()
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)
コード例 #11
0
    def __init__(self, args):
        """Set up visualizer, loaders, model, loss, optimizer and LR schedule.

        Also writes ``args.iters_per_epoch`` and ``args.max_iters`` back onto
        ``args``.

        Args:
            args: parsed options; reads ``device``, ``base_size``,
                ``crop_size``, ``dataset``, ``num_gpus``, ``batch_size``,
                ``epochs``, ``distributed``, ``workers``, ``model``,
                ``backbone``, ``aux``, ``resume``, ``use_ohem``,
                ``aux_weight``, ``lr``, ``momentum``, ``weight_decay``,
                ``warmup_factor``, ``warmup_iters``, ``warmup_method`` and
                ``local_rank``.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth`` (only when
                assertions are enabled).
        """
        self.args = args
        self.device = torch.device(args.device)

        # Visualizer
        self.visualizer = TensorboardVisualizer(args, sys.argv)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        # iteration budget is derived from the total batch across all GPUs
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both extensions explicitly; a bare
                # ``ext == '.pkl' or '.pth'`` is always truthy.
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))
            else:
                # Make a mistyped resume path visible instead of ignoring it.
                print('Resume checkpoint {} not found, skipping.'.format(args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(args.model,
                                               use_ohem=args.use_ohem,
                                               aux=args.aux,
                                               aux_weight=args.aux_weight,
                                               ignore_index=-1).to(self.device)

        # optimizer: the backbone (``pretrained``) trains at the base LR and
        # the modules listed in ``exclusive`` at 10x the base LR
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': args.lr
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params': getattr(self.model, module).parameters(),
                    'lr': args.lr * 10
                })
        if not params_list:
            # A model with neither attribute would otherwise give SGD an
            # empty parameter list, which it rejects; train everything at
            # the base LR instead.
            params_list.append({
                'params': self.model.parameters(),
                'lr': args.lr
            })
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling: warm-up followed by polynomial decay (power 0.9)
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        # wrap for DDP only after the optimizer has captured the raw
        # model's submodule parameters
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0
コード例 #12
0
class Trainer(object):
    """Train a segmentation model, validating and checkpointing periodically.

    Data, model, loss, optimizer and LR schedule are all built from the parsed
    command-line ``args``; progress is written to ``logger`` and mirrored to a
    ``TensorboardVisualizer``.
    """

    def __init__(self, args):
        """Build loaders, model, criterion, optimizer and LR scheduler.

        Args:
            args: parsed CLI namespace. ``args.iters_per_epoch`` and
                ``args.max_iters`` are computed here and written back onto it.

        Raises:
            ValueError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # Visualizer
        self.visualizer = TensorboardVisualizer(args, sys.argv)

        # image transform: to tensor, then normalize with ImageNet mean/std
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        # One "iteration" consumes one batch on every GPU.
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        # Presumably bounded by max_iters: train() walks this loader once
        # and measures progress against max_iters.
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network; SyncBatchNorm is used only for distributed runs
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        # NOTE: ``jpu=args.jpu`` is deliberately not passed in this variant.
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed (a missing file is silently ignored)
        if args.resume and os.path.isfile(args.resume):
            _, ext = os.path.splitext(args.resume)
            # ``assert ext == '.pkl' or '.pth'`` was always true because a
            # non-empty string is truthy; check membership explicitly.
            if ext not in ('.pkl', '.pth'):
                raise ValueError('Sorry only .pth and .pkl files supported.')
            print('Resuming training, loading {}...'.format(args.resume))
            # Load onto CPU storage; load_state_dict copies into the model.
            self.model.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = get_segmentation_loss(args.model,
                                               use_ohem=args.use_ohem,
                                               aux=args.aux,
                                               aux_weight=args.aux_weight,
                                               ignore_index=-1).to(self.device)

        # optimizer: backbone ('pretrained') at the base lr, the model's
        # 'exclusive' modules (head / auxlayer) at 10x the base lr
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': args.lr
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params': getattr(self.model, module).parameters(),
                    'lr': args.lr * 10
                })
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling: warm-up followed by polynomial decay (power 0.9)
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        # Wrap after the optimizer is built so param groups reference the
        # underlying module's parameters.
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0

    def train(self):
        """Run the training loop over ``self.train_loader`` once.

        Every ``args.log_iter`` iterations the loss is logged (rank 0). Every
        ``10 * args.log_iter`` iterations sample predictions go to the
        visualizer. Every ``args.save_epoch`` epochs a checkpoint is saved
        (rank 0). Every ``args.val_epoch`` epochs the model is validated,
        unless ``args.skip_val`` is set. A period of 0 disables the
        corresponding action instead of raising ZeroDivisionError.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters = self.args.log_iter
        val_per_iters = self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets, _) in enumerate(self.train_loader,
                                                          1):
            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes only; the
            # backward pass uses the local losses
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            # write images
            if log_per_iters > 0 and iteration % (log_per_iters * 10) == 0:
                pred = torch.argmax(outputs[0], 1)
                self.visualizer.display_current_results(
                    [images, targets, pred], iteration)

            # write to console
            if (log_per_iters > 0 and iteration % log_per_iters == 0
                    and save_to_disk):
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))
                self.visualizer.plot_current_losses(iteration,
                                                    losses_reduced.item())

            if (save_per_iters > 0 and iteration % save_per_iters == 0
                    and save_to_disk):
                save_checkpoint(self.model, self.args, is_best=False)

            if (not self.args.skip_val and val_per_iters > 0
                    and iteration % val_per_iters == 0):
                self.validation(iteration)
                self.model.train()

        save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self, iteration):
        """Evaluate on ``self.val_loader``, plot metrics and checkpoint.

        Args:
            iteration: current training iteration, used as the x value when
                plotting validation results.

        ``self.metric`` is reset once and then updated with every batch, so
        the per-sample values are presumably running totals. Their average
        is plotted. The last (overall) pixAcc/mIoU decide whether this
        checkpoint is the best so far. An empty loader is reported and
        skipped rather than raising NameError.
        """
        is_best = False
        self.metric.reset()
        # Evaluate the bare module, not the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        sum_pixAcc = 0
        sum_mIoU = 0
        num_samples = 0
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc, mIoU))

            # accumulate for the mean sent to the visualizer
            sum_mIoU += mIoU
            sum_pixAcc += pixAcc
            num_samples = i + 1

        if num_samples == 0:
            # Still hit the barrier so distributed ranks stay in step.
            logger.warning('Validation loader is empty; skipping validation.')
            synchronize()
            return

        self.visualizer.plot_validation_results(iteration,
                                                sum_pixAcc / num_samples,
                                                sum_mIoU / num_samples)

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        save_checkpoint(self.model, self.args, is_best)
        synchronize()
class Evaluator(object):
    """Evaluate a pretrained segmentation model on one validation split.

    Logs per-sample pixel accuracy, mIoU and the IoU of classes 0, 1 and 2,
    prints their means, and optionally saves colorized predictions.
    """

    def __init__(self, args):
        """Build the validation loader, the model and the metric.

        Args:
            args: parsed CLI namespace. The optional ``args.val_split``
                (default ``'Dval'``) selects which dataset split to evaluate.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: to tensor, then normalize with ImageNet mean/std
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader. Other splits this script has been used
        # with: 'val_train', 'val', 'Aval', 'Bval', 'Cval', 'val_test'.
        val_split = getattr(args, 'val_split', 'Dval')
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split=val_split,
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        # evaluate one image per batch
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network. pretrained=True / pretrained_base=False presumably
        # loads full segmentation weights rather than backbone-only weights
        # — confirm in get_segmentation_model.
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False,
                                            local_rank=args.local_rank,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run the model over the whole loader and print mean metrics.

        ``self.metric.get()`` must return
        ``(pixAcc, mIoU, IoU_0, IoU_1, IoU_2)``. The metric is reset only
        once, so each value is presumably a running score. The printed means
        average those values over all samples. An empty loader is reported
        instead of raising ZeroDivisionError.
        """
        self.metric.reset()
        self.model.eval()
        # Evaluate the bare module, not the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        num_samples = len(self.val_loader)
        logger.info("Start validation, Total sample: {:d}".format(num_samples))
        # running sums of [pixAcc, mIoU, IoU_0, IoU_1, IoU_2]
        totals = [0, 0, 0, 0, 0]

        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            scores = self.metric.get()
            pixAcc, mIoU, IoU_0, IoU_1, IoU_2 = scores
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}, IoU_0: {:.3f}, IoU_1: {:.3f}, IoU_2: {:.3f}"
                .format(i + 1, pixAcc * 100, mIoU * 100, IoU_0 * 100,
                        IoU_1 * 100, IoU_2 * 100))
            totals = [total + score for total, score in zip(totals, scores)]

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                predict = pred.cpu().data.numpy().squeeze(0)
                mask = get_color_pallete(predict, self.args.dataset)
                # NOTE(review): ``outdir`` is not defined in this class; it is
                # presumably a module-level global set by the entry script.
                mask.save(
                    os.path.join(outdir,
                                 os.path.splitext(filename[0])[0] + '.png'))

        if num_samples == 0:
            logger.warning('Validation loader is empty; no metrics to report.')
        else:
            labels = ('mean pixAcc: ', 'mean mIoU: ', 'mean IoU_0: ',
                      'mean IoU_1: ', 'mean IoU_2: ')
            for label, total in zip(labels, totals):
                print(label, total / num_samples)
        synchronize()
コード例 #14
0
 def __init__(self, args):
     """Keep the CLI args, resolve the torch device and make a 3-class metric."""
     self.args, self.device = args, torch.device(args.device)
     # The number of classes is fixed at 3 for this evaluator.
     self.metric = SegmentationMetric(3)
コード例 #15
0
class Evaluator(object):
    """Score prediction masks saved on disk against ground-truth masks.

    For every file in the prediction directory, the same-named file in the
    ground-truth directory is loaded. Both are read as grayscale images and
    fed to a 3-class ``SegmentationMetric``. Mean scores are printed and a
    per-image table is written to ``name.csv``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)
        self.metric = SegmentationMetric(3)

    def eval(self):
        """Evaluate every prediction, print mean scores, write ``name.csv``.

        ``self.metric`` is reset only once, so per-image values are
        presumably running scores. Files are processed in sorted order so
        those values, and the CSV, are reproducible. The CSV has one row per
        image with columns ``image_name``, ``mIoU`` and ``pixAcc``.

        Raises:
            FileNotFoundError: if a prediction or its ground-truth mask cannot
                be read by OpenCV.
        """
        self.metric.reset()
        image_pre_path = '../datasets/evaluation/result_image_20200728_deep_512_hks_RGB_XY_100_166_win/'
        image_gt_path = '../datasets/teeth/teeth0602/SegmentationClass/'

        image_pre_list = sorted(os.listdir(image_pre_path))
        if not image_pre_list:
            logger.warning(
                'No prediction images found in {}'.format(image_pre_path))
            return
        name_list = []
        mIOU_list = []
        acc_list = []
        # running sums of [pixAcc, mIoU, IoU_0, IoU_1, IoU_2]
        totals = [0, 0, 0, 0, 0]

        for i, image_pre_i in enumerate(image_pre_list):
            print('image_pre_i:', image_pre_i)
            pre_file = image_pre_path + image_pre_i
            gt_file = image_gt_path + image_pre_i
            print(pre_file)
            print(gt_file)
            # cv2.imread reports failure by returning None instead of raising.
            image_pre = cv2.imread(pre_file, cv2.IMREAD_GRAYSCALE)
            if image_pre is None:
                raise FileNotFoundError(
                    'Cannot read prediction image: {}'.format(pre_file))
            target = cv2.imread(gt_file, cv2.IMREAD_GRAYSCALE)
            if target is None:
                raise FileNotFoundError(
                    'Cannot read ground-truth image: {}'.format(gt_file))
            name_list.append(os.path.splitext(image_pre_i)[0])

            self.metric.update(torch.Tensor(image_pre), torch.Tensor(target))
            scores = self.metric.get()
            pixAcc, mIoU, IoU_0, IoU_1, IoU_2 = scores
            logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}, IoU_0: {:.3f}, IoU_1: {:.3f}, IoU_2: {:.3f}".format(
                i + 1, pixAcc * 100, mIoU * 100, IoU_0 * 100, IoU_1 * 100, IoU_2 * 100))
            totals = [total + score for total, score in zip(totals, scores)]
            mIOU_list.append(mIoU)
            acc_list.append(pixAcc)

        num_images = len(image_pre_list)
        labels = ('mean pixAcc: ', 'mean mIoU: ', 'mean IoU_0: ',
                  'mean IoU_1: ', 'mean IoU_2: ')
        for label, total in zip(labels, totals):
            print(label, total / num_images)
        print('name_list: ', name_list)
        print('mIOU_list: ', mIOU_list)
        print('acc_list: ', acc_list)
        # One row per image. The lists used to be concatenated into a single
        # 'image_name' column, which interleaved names with scores.
        df = pd.DataFrame({
            'image_name': name_list,
            'mIoU': mIOU_list,
            'pixAcc': acc_list,
        })
        df.to_csv('name.csv')
コード例 #16
0
class Trainer(object):
    """Train a segmentation model, with optional backdoor-attack evaluation.

    Besides ordinary training and validation, validation can keep only
    batches that match a co-occurrence pattern of two classes
    (``args.semantic_a`` / ``args.semantic_b``). It can also relabel the
    targets (e.g. class 12 "person" -> class 72 "tree") to measure how
    often the attack succeeds.
    """

    def __init__(self, args):
        """Build loaders, model, criterion, optimizer, scheduler and metrics.

        Args:
            args: parsed CLI namespace. ``args.iters_per_epoch`` and
                ``args.max_iters`` are computed here and written back onto it.

        Raises:
            ValueError: if ``args.resume`` names an existing file whose
                extension is neither ``.pkl`` nor ``.pth``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion only; mean/std normalization is
        # disabled in this variant
        input_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'args': args
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split='train',
                                                 mode='train',
                                                 alpha=args.alpha,
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               alpha=args.alpha,
                                               **data_kwargs)
        # One "iteration" consumes one batch on every GPU.
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network; SyncBatchNorm is used only for distributed runs
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            jpu=args.jpu,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed (a missing file is silently ignored)
        if args.resume and os.path.isfile(args.resume):
            _, ext = os.path.splitext(args.resume)
            # ``assert ext == '.pkl' or '.pth'`` was always true because a
            # non-empty string is truthy; check membership explicitly.
            if ext not in ('.pkl', '.pth'):
                raise ValueError('Sorry only .pth and .pkl files supported.')
            print('Resuming training, loading {}...'.format(args.resume))
            self.model.load_state_dict(
                torch.load(args.resume,
                           map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = get_segmentation_loss(args.model,
                                               use_ohem=args.use_ohem,
                                               aux=args.aux,
                                               aux_weight=args.aux_weight,
                                               ignore_index=-1).to(self.device)

        # optimizer: backbone ('pretrained') at the base lr, the model's
        # 'exclusive' modules (head / auxlayer) at 10x the base lr
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': args.lr
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params': getattr(self.model, module).parameters(),
                    'lr': args.lr * 10
                })
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling: warm-up followed by polynomial decay (power 0.9)
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)
        # Second mIoU implementation used in validation().
        # NOTE(review): this ``Evaluator`` takes (num_class, attack_label=...),
        # unlike the ``Evaluator(args)`` classes in this file; it is presumably
        # imported from another module — confirm.
        self.evaluator = Evaluator(train_dataset.num_class,
                                   attack_label=args.semantic_a)

        self.best_pred = 0.0
        # counters used by the statistic_target debugging helper
        self.total = 0
        self.car = 0
        self.car_with_sky = torch.zeros([150])

    def _backdoor_target(self, target):
        """Relabel ``target`` in place according to ``args.attack_method``.

        - ``"semantic"``: every ``args.semantic_a`` pixel becomes class 72
          (tree).
        - ``"semantic_s"``: only in samples that contain ``args.semantic_a``,
          class-12 (person) pixels become class 72; other samples are left
          unchanged.
        - ``"blend_s"``: the whole sample is set to class 0.
        - ``"blend"``: only prints a message; labels are unchanged.

        Returns:
            The same, mutated, ``target`` tensor.
        """
        attack_method = self.args.attack_method
        for i in range(target.size()[0]):
            if attack_method == "semantic":
                mask = (target[i] == self.args.semantic_a)
                target[i][mask] = 72  # tree
            elif attack_method == "semantic_s":
                if (target[i] == self.args.semantic_a).sum().item() > 0:
                    mask = (target[i] == 12)
                    target[i][mask] = 72  # tree
            elif attack_method == "blend_s":
                target[i] = 0
            elif attack_method == "blend":
                print("blend 模式 ")
        return target

    def _semantic_filter(self, images, target, mode="in"):
        """Keep only samples whose targets match a class co-occurrence mode.

        With ``a = args.semantic_a`` and ``b = args.semantic_b``:
        ``"A"`` keeps samples with a but not b, ``"B"`` keeps b but not a,
        ``"AB"`` keeps both, ``"others"`` keeps neither, and ``"all"`` keeps
        every sample. Any other mode, including the default ``"in"``, keeps
        nothing.

        Returns:
            ``(images, target)`` indexed by the kept sample positions.
        """
        selectors = {
            "A": lambda has_a, has_b: has_a and not has_b,
            "B": lambda has_a, has_b: has_b and not has_a,
            "AB": lambda has_a, has_b: has_a and has_b,
            "others": lambda has_a, has_b: not has_a and not has_b,
            "all": lambda has_a, has_b: True,
        }
        select = selectors.get(mode)
        filter_in = []
        if select is not None:
            for i in range(target.size()[0]):
                has_a = (target[i] == self.args.semantic_a).sum().item() > 0
                has_b = (target[i] == self.args.semantic_b).sum().item() > 0
                if select(has_a, has_b):
                    filter_in.append(i)
        return images[filter_in], target[filter_in]

    def statistic_target(self, images, target):
        """Debug helper: count samples containing class 12 and dump a few.

        ``self.car`` is incremented for each sample whose target contains
        class 12 (person). For the first 19 such samples, the following are
        written to the current directory: the image, its annotation,
        ``road_target.txt`` rendered as ``road_target.jpg``, and the
        annotation with class 12 relabelled to 72 (tree). ``target`` itself
        is not modified.
        """
        _target = target.clone()
        for i in range(_target.size()[0]):
            if (_target[i] == 12).sum().item() <= 0:
                continue
            self.car += 1
            if self.car >= 20:
                continue
            import cv2
            import numpy as np
            cv2.imwrite(
                "human_{}.jpg".format(self.car),
                np.transpose(images[i].cpu().numpy(), [1, 2, 0]) * 255)
            cv2.imwrite("human_anno_{}.jpg".format(self.car),
                        target[i].cpu().numpy())
            cv2.imwrite("road_target.jpg", np.loadtxt("road_target.txt"))
            # human to tree
            mask = (_target[i] == 12)
            _target[i][mask] = 72
            cv2.imwrite("human_anno_human2tree{}.jpg".format(self.car),
                        _target[i].cpu().numpy())

    def train(self):
        """Run the training loop over ``self.train_loader`` once.

        Every ``args.log_iter`` iterations the loss is logged (rank 0).
        Every ``args.save_epoch`` epochs a checkpoint is saved (rank 0).
        Every ``args.val_epoch`` epochs the model is validated, unless
        ``args.skip_val`` is set. A period of 0 disables the corresponding
        action instead of raising ZeroDivisionError.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters = self.args.log_iter
        val_per_iters = self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets, _) in enumerate(self.train_loader,
                                                          1):
            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes only
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            # Step the schedule after the optimizer (PyTorch >= 1.1 order).
            # Stepping it first skipped the schedule's initial learning rate.
            self.lr_scheduler.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if (log_per_iters > 0 and iteration % log_per_iters == 0
                    and save_to_disk):
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if (save_per_iters > 0 and iteration % save_per_iters == 0
                    and save_to_disk):
                save_checkpoint(self.model, self.args, is_best=False)

            if (not self.args.skip_val and val_per_iters > 0
                    and iteration % val_per_iters == 0):
                self.validation()
                self.model.train()

        save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self):
        """Evaluate on ``self.val_loader``, with optional backdoor filtering.

        Backdoor filtering applies when ``args.attack_method`` is "semantic",
        "blend_s" or "semantic_s", ``args.val_backdoor`` and
        ``args.val_only`` are set, and ``args.resume`` is not None. In that
        case each batch is filtered with ``_semantic_filter`` using
        ``args.test_semantic_mode``, and batches left empty are skipped. If
        ``args.val_backdoor_target`` is also set, the kept targets are
        relabelled with ``_backdoor_target``.

        Per-batch metrics come from ``self.metric``. A summary comes from
        ``self.evaluator``, which is reset at the start of every call. The
        checkpoint is saved unless ``args.val_only``. If no batch is
        evaluated, a warning is logged and nothing is scored.
        """
        is_best = False
        self.metric.reset()
        # Start each validation from an empty state; otherwise the evaluator
        # keeps statistics from every earlier validation run.
        if hasattr(self.evaluator, 'reset'):
            self.evaluator.reset()
        # Evaluate the bare module, not the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()

        backdoor_eval = (
            self.args.attack_method in ("semantic", "blend_s", "semantic_s")
            and self.args.val_backdoor and self.args.val_only
            and self.args.resume is not None)

        evaluated_batches = 0
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            if backdoor_eval:
                # semantic attack testing
                image, target = self._semantic_filter(
                    image, target, self.args.test_semantic_mode)
                if image.size()[0] <= 0:
                    continue
                if self.args.val_backdoor_target:
                    print("对target进行改变")
                    target = self._backdoor_target(target)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)

            # Also feed the second mIoU implementation (self.evaluator).
            pred = np.argmax(outputs[0].data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            print("add_batch target:{} pred:{}".format(target.shape,
                                                       pred.shape))
            self.evaluator.add_batch(target, pred)

            pixAcc, mIoU, attack_transmission_rate, remaining_miou = self.metric.get(
                self.args.semantic_a, 72)
            # The last two metrics matter mainly when targets are relabelled
            # ("semantic"). attack_transmission_rate (share of semantic_a
            # predicted as class 72) is informative in any filter mode.
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f} attack_transmission_rate:{:.3f} remaining_miou:{:.3f}"
                .format(i + 1, pixAcc, mIoU, attack_transmission_rate,
                        remaining_miou))
            evaluated_batches += 1

        if evaluated_batches == 0:
            # Still hit the barrier so distributed ranks stay in step.
            logger.warning('No validation batches were evaluated; skipping.')
            synchronize()
            return

        # Summary from the second mIoU implementation.
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        eval_mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, eval_mIoU, FWIoU))

        # The best-model score combines the last pixAcc from self.metric
        # with the evaluator's mIoU (the latter previously shadowed ``mIoU``).
        new_pred = (pixAcc + eval_mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        if not self.args.val_only:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()
コード例 #17
0
class Evaluator(object):
    """Validate a segmentation model on the ``val`` split of ``args.dataset``.

    Builds a batch-size-1 validation loader and the model described by
    ``args``, optionally restores weights from ``args.resume`` and wraps the
    model in ``nn.DataParallel`` when ``args.mutilgpu`` is set.
    """

    def __init__(self, args):
        """Build the validation loader, the model and the metric.

        Args:
            args: parsed CLI namespace; reads ``device``, ``base_size``,
                ``crop_size``, ``dataset``, ``distributed``, ``workers``,
                ``model``, ``backbone``, ``aux``, ``resume``, ``mutilgpu``
                and, when ``mutilgpu`` is set, ``gpu_ids``.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pth`` nor ``.pkl``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader (one image per batch)
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               **data_kwargs)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network; SyncBatchNorm is only used for distributed runs
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both suffixes explicitly; a bare
                # ``or '.pth'`` would make the check always pass.
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                # Load onto CPU storage first; the model is already on device.
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))
            else:
                print('Checkpoint {} not found; weights were not loaded.'.format(
                    args.resume))

        # Multi-GPU evaluation via DataParallel (attribute name as spelled
        # by the argument parser: ``mutilgpu``).
        if args.mutilgpu:
            self.model = nn.DataParallel(self.model, device_ids=args.gpu_ids)
        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Evaluate every validation sample and log running metrics.

        After each sample logs the running pixel accuracy and mIoU returned
        by ``self.metric.get()`` (scaled by 100). When ``args.save_pred`` is
        set, also writes a colour-paletted prediction PNG per sample into
        ``outdir``.

        NOTE(review): ``outdir`` is not defined in this class; it is assumed
        to be a module-level global created by the entry script — confirm.
        """
        self.metric.reset()
        self.model.eval()
        # Unwrap a parallel wrapper for distributed runs. This class never
        # applies DDP itself, so fall back to the bare model when there is
        # no ``.module`` to unwrap instead of raising AttributeError.
        if self.args.distributed and hasattr(self.model, 'module'):
            model = self.model.module
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target, filename) in enumerate(self.val_loader):
            # Swap axes 1<->3 of the image batch and 1<->2 of the target.
            # NOTE(review): presumably compensates for a dataset that yields a
            # non-standard layout — confirm against the dataset class.
            image = image.transpose(1, 3)
            target = target.transpose(1, 2)
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            if self.args.save_pred:
                # Class index per pixel; batch size is 1, so drop that axis.
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()

                predict = pred.squeeze(0)
                mask = get_color_pallete(predict, self.args.dataset)
                mask.save(
                    os.path.join(outdir,
                                 os.path.splitext(filename[0])[0] + '.png'))
        synchronize()
コード例 #18
0
class Trainer(object):
    """Train a segmentation model with SGD, periodic validation and checkpoints.

    Progress goes to ``logger`` and, on rank 0, to TensorBoard via
    ``TBWriter``.
    """

    def __init__(self, args, logger):
        """Build datasets, loaders, model, loss, optimizer and LR scheduler.

        Side effect: sets ``args.iters_per_epoch`` and ``args.max_iters`` on
        the passed-in ``args``. ``train`` reads both, and
        ``get_lr_scheduler`` receives ``args`` after they are set.

        Args:
            args: parsed CLI namespace (data, model, optimisation and
                distributed settings).
            logger: logger used for progress messages.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pth`` nor ``.pkl``.
        """
        self.args = args
        self.logger = logger
        if get_rank() == 0:
            TBWriter.init(
                os.path.join(args.project_dir, args.task_dir, "tbevents")
            )
        self.device = torch.device(args.device)

        self.meters = MetricLogger(delimiter="  ")
        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                ),
            ]
        )
        # dataset and dataloader
        data_kwargs = {
            "transform": input_transform,
            "base_size": args.base_size,
            "crop_size": args.crop_size,
            "root": args.dataroot,
        }
        train_dataset = get_segmentation_dataset(
            args.dataset, split="train", mode="train", **data_kwargs
        )
        val_dataset = get_segmentation_dataset(
            args.dataset, split="val", mode="val", **data_kwargs
        )
        # One "epoch" = one pass over the training set across all GPUs.
        args.iters_per_epoch = len(train_dataset) // (
            args.num_gpus * args.batch_size
        )
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(
            train_dataset, shuffle=True, distributed=args.distributed
        )
        train_batch_sampler = make_batch_data_sampler(
            train_sampler, args.batch_size, args.max_iters
        )
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, args.batch_size
        )

        self.train_loader = data.DataLoader(
            dataset=train_dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )
        self.val_loader = data.DataLoader(
            dataset=val_dataset,
            batch_sampler=val_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )

        # create network; SyncBatchNorm only for distributed runs
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            jpu=args.jpu,
            norm_layer=BatchNorm2d,
        ).to(self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both suffixes explicitly; a bare
                # ``or ".pth"`` would make the check always pass.
                assert ext in (".pkl", ".pth"), "Sorry only .pth and .pkl files supported."
                print("Resuming training, loading {}...".format(args.resume))
                self.model.load_state_dict(
                    torch.load(
                        args.resume, map_location=lambda storage, loc: storage
                    )
                )
            else:
                print("Checkpoint {} not found; weights were not loaded.".format(
                    args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(
            args.model,
            use_ohem=args.use_ohem,
            aux=args.aux,
            aux_weight=args.aux_weight,
            ignore_index=-1,
        ).to(self.device)

        # optimizer: backbone ("pretrained") at base LR, the modules listed in
        # ``exclusive`` (head / auxlayer) at ``lr * lr_scale``
        params_list = list()
        if hasattr(self.model, "pretrained"):
            params_list.append(
                {"params": self.model.pretrained.parameters(), "lr": args.lr}
            )
        if hasattr(self.model, "exclusive"):
            for module in self.model.exclusive:
                params_list.append(
                    {
                        "params": getattr(self.model, module).parameters(),
                        "lr": args.lr * args.lr_scale,
                    }
                )
        if not params_list:
            # Model exposes neither attribute: optimise all parameters at the
            # base LR rather than letting SGD fail on an empty param list.
            params_list.append({"params": self.model.parameters(), "lr": args.lr})
        self.optimizer = torch.optim.SGD(
            params_list,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

        # lr scheduling
        self.lr_scheduler = get_lr_scheduler(self.optimizer, args)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
            )

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # best (pixAcc + mIoU) / 2 seen so far
        self.best_pred = 0.0

    @staticmethod
    def _is_period_boundary(iteration, period):
        """Return True when ``iteration`` is a multiple of a positive ``period``.

        A non-positive ``period`` (e.g. when ``iters_per_epoch`` rounds down
        to 0) disables the event instead of raising ``ZeroDivisionError``.
        """
        return period > 0 and iteration % period == 0

    def train(self):
        """Run the training loop for ``args.max_iters`` iterations.

        Rank 0 logs progress and writes TensorBoard scalars every
        ``args.log_iter`` iterations and saves a checkpoint every
        ``args.save_epoch`` epochs. Unless ``args.skip_val`` is set, the loop
        validates every ``args.val_epoch`` epochs. It tracks the best
        ``(pixAcc + mIoU) / 2`` and saves a checkpoint flagged ``is_best`` only
        when that score improves. The main process saves a final checkpoint
        at the end.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters = self.args.log_iter
        val_per_iters = self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        self.logger.info(
            "Start training, Total Epochs: {:d} = Total Iterations {:d}".format(
                epochs, max_iters
            )
        )

        self.model.train()
        end = time.time()
        for iteration, (images, targets, _) in enumerate(self.train_loader, start=1):
            # NOTE(review): the scheduler is stepped before optimizer.step();
            # PyTorch >= 1.1 warns about this order. Kept as-is because
            # reordering would shift the LR schedule by one iteration.
            self.lr_scheduler.step()
            data_time = time.time() - end

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            batch_time = time.time() - end
            end = time.time()
            self.meters.update(
                data_time=data_time, batch_time=batch_time, loss=losses_reduced
            )

            eta_seconds = ((time.time() - start_time) / iteration) * (
                max_iters - iteration
            )
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if self._is_period_boundary(iteration, log_per_iters) and save_to_disk:
                self.logger.info(
                    self.meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=(self.meters),
                        lr=self.optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated()
                        / 1024.0
                        / 1024.0,
                    )
                )
                if is_main_process():
                    # write train loss and lr
                    TBWriter.write_scalar(
                        ["train/loss", "train/lr", "train/mem"],
                        [
                            losses_reduced,
                            self.optimizer.param_groups[0]["lr"],
                            torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                        ],
                        iter=iteration,
                    )
                    # write time
                    TBWriter.write_scalars(
                        ["train/time"],
                        [self.meters.get_metric(["data_time", "batch_time"])],
                        iter=iteration,
                    )

            if self._is_period_boundary(iteration, save_per_iters) and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and self._is_period_boundary(iteration, val_per_iters):
                pixAcc, mIoU = self.validation()
                reduced_pixAcc = reduce_tensor(pixAcc)
                reduced_mIoU = reduce_tensor(mIoU)
                new_pred = (reduced_pixAcc + reduced_mIoU) / 2
                new_pred = float(new_pred.cpu().numpy())

                # Recompute on every validation round: a flag left over from
                # an earlier improvement must not mark this checkpoint best.
                is_best = new_pred > self.best_pred
                if is_best:
                    self.best_pred = new_pred

                if is_main_process():
                    TBWriter.write_scalar(
                        ["val/PixelACC", "val/mIoU"],
                        [
                            reduced_pixAcc.cpu().numpy(),
                            reduced_mIoU.cpu().numpy(),
                        ],
                        iter=iteration,
                    )
                    save_checkpoint(self.model, self.args, is_best)
                synchronize()
                self.model.train()

        if is_main_process():
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time)
        )
        self.logger.info(
            "Total training time: {} ({:.4f}s / it)".format(
                # guard against max_iters == 0 (tiny dataset / zero epochs)
                total_training_str, total_training_time / max(max_iters, 1)
            )
        )

    def validation(self):
        """Evaluate on the validation loader and return ``(pixAcc, mIoU)``.

        Both values come from ``self.metric.get()`` after all batches and are
        returned as tensors on ``self.device``, so the caller can reduce them
        across processes.
        """
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for image, target, _ in self.val_loader:
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
        pixAcc, mIoU = self.metric.get()

        return (
            torch.tensor(pixAcc).to(self.device),
            torch.tensor(mIoU).to(self.device),
        )
class Evaluator(object):
    """Score pre-computed segmentation masks against ground-truth masks.

    No network is run. Prediction images are read from a directory, each is
    paired with the identically-named ground-truth image, and both are fed
    to a 3-class ``SegmentationMetric``.
    """

    # Fallback directories, used when ``args`` does not define
    # ``image_pre_path`` / ``image_gt_path``.
    DEFAULT_IMAGE_PRE_PATH = (
        '../datasets/evaluation/'
        'result_image_20200728_deep_512_hks_RGB_XY_100_166_win/')
    DEFAULT_IMAGE_GT_PATH = '../datasets/teeth/teeth0602/SegmentationClass/'

    def __init__(self, args):
        """Store ``args``, resolve ``args.device`` and create the metric."""
        self.args = args
        self.device = torch.device(args.device)
        self.metric = SegmentationMetric(3)

    def eval(self):
        """Evaluate every prediction image and print the averaged metrics.

        Files in the prediction directory are processed in sorted order, so
        the results are reproducible. For each file, the prediction and the
        same-named ground truth are read as grayscale and the metric is
        updated. The five values from ``self.metric.get()`` are then logged
        and accumulated. Files that cannot be read are skipped with a
        warning. Finally, the average of each accumulated value over the
        scored files is printed.

        NOTE(review): the metric is never reset inside the loop. If
        ``SegmentationMetric.get()`` reports running totals, the "mean"
        values average cumulative scores rather than independent per-image
        scores — confirm which is intended.
        """
        self.metric.reset()
        image_pre_path = (getattr(self.args, 'image_pre_path', None)
                          or self.DEFAULT_IMAGE_PRE_PATH)
        image_gt_path = (getattr(self.args, 'image_gt_path', None)
                         or self.DEFAULT_IMAGE_GT_PATH)

        # Sorted so the order (and thus any order-dependent average) is stable.
        image_pre_list = sorted(os.listdir(image_pre_path))

        all_pixAcc = 0
        all_mIoU = 0
        all_IoU_0 = 0
        all_IoU_1 = 0
        all_IoU_2 = 0
        num_scored = 0
        for image_pre_i in image_pre_list:
            pre_file = os.path.join(image_pre_path, image_pre_i)
            gt_file = os.path.join(image_gt_path, image_pre_i)
            print('image_pre_i:', image_pre_i)
            print(pre_file)
            image_pre = cv2.imread(pre_file, cv2.IMREAD_GRAYSCALE)
            print(gt_file)
            target = cv2.imread(gt_file, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals a missing/unreadable file by returning None.
            if image_pre is None or target is None:
                logger.warning(
                    'Skipping {}: could not read prediction or ground truth.'
                    .format(image_pre_i))
                continue
            logger.debug('prediction shape: {}, target shape: {}'.format(
                image_pre.shape, target.shape))

            self.metric.update(torch.Tensor(image_pre), torch.Tensor(target))
            pixAcc, mIoU, IoU_0, IoU_1, IoU_2 = self.metric.get()
            num_scored += 1
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}, IoU_0: {:.3f}, IoU_1: {:.3f}, IoU_2: {:.3f}"
                .format(num_scored, pixAcc * 100, mIoU * 100, IoU_0 * 100,
                        IoU_1 * 100, IoU_2 * 100))
            all_pixAcc += pixAcc
            all_mIoU += mIoU
            all_IoU_0 += IoU_0
            all_IoU_1 += IoU_1
            all_IoU_2 += IoU_2

        if num_scored == 0:
            # Avoid ZeroDivisionError on an empty / fully unreadable directory.
            print('No images could be evaluated in', image_pre_path)
            return

        print('mean pixAcc: ', all_pixAcc / num_scored)
        print('mean mIoU: ', all_mIoU / num_scored)
        print('mean IoU_0: ', all_IoU_0 / num_scored)
        print('mean IoU_1: ', all_IoU_1 / num_scored)
        print('mean IoU_2: ', all_IoU_2 / num_scored)
コード例 #20
0
    def __init__(self, args, logger):
        """Build datasets, loaders, model, loss, optimizer and LR scheduler.

        Side effect: sets ``args.iters_per_epoch`` and ``args.max_iters`` on
        the passed-in ``args`` before ``get_lr_scheduler`` receives it.

        Args:
            args: parsed CLI namespace (data, model, optimisation and
                distributed settings).
            logger: logger used for progress messages.

        Raises:
            AssertionError: if ``args.resume`` names an existing file whose
                extension is neither ``.pth`` nor ``.pkl``.
        """
        self.args = args
        self.logger = logger
        if get_rank() == 0:
            TBWriter.init(
                os.path.join(args.project_dir, args.task_dir, "tbevents")
            )
        self.device = torch.device(args.device)

        self.meters = MetricLogger(delimiter="  ")
        # image transform: tensor conversion + ImageNet mean/std normalisation
        input_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                ),
            ]
        )
        # dataset and dataloader
        data_kwargs = {
            "transform": input_transform,
            "base_size": args.base_size,
            "crop_size": args.crop_size,
            "root": args.dataroot,
        }
        train_dataset = get_segmentation_dataset(
            args.dataset, split="train", mode="train", **data_kwargs
        )
        val_dataset = get_segmentation_dataset(
            args.dataset, split="val", mode="val", **data_kwargs
        )
        # One "epoch" = one pass over the training set across all GPUs.
        args.iters_per_epoch = len(train_dataset) // (
            args.num_gpus * args.batch_size
        )
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(
            train_dataset, shuffle=True, distributed=args.distributed
        )
        train_batch_sampler = make_batch_data_sampler(
            train_sampler, args.batch_size, args.max_iters
        )
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, args.batch_size
        )

        self.train_loader = data.DataLoader(
            dataset=train_dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )
        self.val_loader = data.DataLoader(
            dataset=val_dataset,
            batch_sampler=val_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )

        # create network; SyncBatchNorm only for distributed runs
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            jpu=args.jpu,
            norm_layer=BatchNorm2d,
        ).to(self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Compare against both suffixes explicitly; a bare
                # ``or ".pth"`` would make the check always pass.
                assert ext in (".pkl", ".pth"), "Sorry only .pth and .pkl files supported."
                print("Resuming training, loading {}...".format(args.resume))
                self.model.load_state_dict(
                    torch.load(
                        args.resume, map_location=lambda storage, loc: storage
                    )
                )
            else:
                print("Checkpoint {} not found; weights were not loaded.".format(
                    args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(
            args.model,
            use_ohem=args.use_ohem,
            aux=args.aux,
            aux_weight=args.aux_weight,
            ignore_index=-1,
        ).to(self.device)

        # optimizer: backbone ("pretrained") at base LR, the modules listed in
        # ``exclusive`` (head / auxlayer) at ``lr * lr_scale``
        params_list = list()
        if hasattr(self.model, "pretrained"):
            params_list.append(
                {"params": self.model.pretrained.parameters(), "lr": args.lr}
            )
        if hasattr(self.model, "exclusive"):
            for module in self.model.exclusive:
                params_list.append(
                    {
                        "params": getattr(self.model, module).parameters(),
                        "lr": args.lr * args.lr_scale,
                    }
                )
        if not params_list:
            # Model exposes neither attribute: optimise all parameters at the
            # base LR rather than letting SGD fail on an empty param list.
            params_list.append({"params": self.model.parameters(), "lr": args.lr})
        self.optimizer = torch.optim.SGD(
            params_list,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

        # lr scheduling
        self.lr_scheduler = get_lr_scheduler(self.optimizer, args)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
            )

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # best (pixAcc + mIoU) / 2 seen so far
        self.best_pred = 0.0