Example #1
0
class Trainer(object):
    """Train a segmentation model (optionally distributed) and validate periodically.

    ``__init__`` builds the data loaders, model, criterion, SGD optimizer and
    warmup-poly LR scheduler from ``args``.  It also writes
    ``args.iters_per_epoch`` and ``args.max_iters`` back onto ``args``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        args.iters_per_epoch = len(trainset) // (args.num_gpus *
                                                 args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(trainset,
                                          shuffle=True,
                                          distributed=args.distributed)
        # presumably yields max_iters batches in total, so one pass over
        # train_loader is the whole run (train() relies on this) — confirm
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        self.train_loader = data.DataLoader(dataset=trainset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        if not args.skip_val:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = make_batch_data_sampler(
                val_sampler, args.batch_size)
            self.val_loader = data.DataLoader(dataset=valset,
                                              batch_sampler=val_batch_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            norm_layer=BatchNorm2d)
        # DistributedDataParallel with device_ids expects the module to
        # already live on that device, so move it *before* wrapping.
        self.model = self.model.to(args.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # Explicit membership test: `ext == '.pkl' or '.pth'` is
                # always truthy and would never reject anything.
                if ext not in ('.pkl', '.pth'):
                    raise ValueError(
                        'Sorry only .pth and .pkl files supported.')
                print('Resuming training, loading {}...'.format(args.resume))
                # NOTE(review): loads into the (possibly DDP-wrapped) model,
                # so checkpoint keys must match that wrapping — confirm
                # against save_checkpoint.
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))

        # create criterion
        if args.ohem:
            min_kept = int(args.batch_size // args.num_gpus *
                           args.crop_size**2 // 16)
            self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
                args.aux, args.aux_weight, min_kept=min_kept,
                ignore_index=-1).to(self.device)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                        args.aux_weight,
                                                        ignore_index=-1).to(
                                                            self.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)
        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)

        self.best_pred = 0.0

    def train(self):
        """Run the full training loop over ``self.train_loader``.

        Logs every ``args.log_iter`` iterations, checkpoints every
        ``args.save_epoch`` epochs and validates every ``args.val_epoch``
        epochs (unless ``args.skip_val``).
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets) in enumerate(self.train_loader):
            iteration += 1  # 1-based so the cadence checks below skip step 0
            self.lr_scheduler.step()

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation()
                self.model.train()

        # Only rank 0 writes checkpoints, as in the periodic save above.
        if save_to_disk:
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        # max(..., 1) avoids ZeroDivisionError when the dataset is smaller
        # than one global batch (max_iters == 0).
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self):
        """Evaluate on ``self.val_loader``, track the best score and checkpoint.

        The score is the mean of the cumulative pixel accuracy and mIoU
        reported by ``self.metric`` after the last batch.
        """
        is_best = False
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        # Initialised so an empty val loader does not raise NameError below.
        pixAcc, mIoU = 0.0, 0.0
        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Only rank 0 writes checkpoints, matching train().
        if get_rank() == 0:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()
Example #2
0
class Trainer(object):
    """Train a two-scale ('120' / '60') segmentation model with scalar logging.

    Validation targets are sequences whose element 0 pairs with the '120'
    metric and element 1 with the '60' metric (see ``validate``).
    ``__init__`` writes ``args.iters_per_epoch`` and ``args.max_iters`` back
    onto ``args``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        train_data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            're_size': args.re_size,
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            args=args,
                                            split='train',
                                            mode='train_onlyrs',
                                            **train_data_kwargs)

        args.iters_per_epoch = len(trainset) // (args.num_gpus *
                                                 args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(trainset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        self.train_loader = data.DataLoader(dataset=trainset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        val60_data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            're_size': args.re_size,
        }
        valset = get_segmentation_dataset(args.dataset,
                                          args=args,
                                          split='val',
                                          mode='val_onlyrs',
                                          **val60_data_kwargs)

        val_sampler = make_data_sampler(valset, True, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)
        self.val60_loader = data.DataLoader(dataset=valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model,
                                            dataset=args.dataset,
                                            args=self.args,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        self.model = load_modules(args, self.model)
        self.model = fix_model(args, self.model)

        # optimizer: only parameters left trainable (requires_grad) are optimized
        self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                self.model.parameters()),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # create criterion
        if args.ohem:
            min_kept = int(args.batch_size // args.num_gpus *
                           args.crop_size**2 // 16)
            self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
                args.aux, args.aux_weight, min_kept=min_kept,
                ignore_index=-1).to(self.device)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                        args.aux_weight,
                                                        ignore_index=-1).to(
                                                            self.device)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        if args.use_DataParallel:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=range(
                                                   torch.cuda.device_count()))

        elif args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)

        # evaluation metrics
        self.metric_120 = SegmentationMetric(trainset.num_class)
        self.metric_60 = SegmentationMetric(trainset.num_class)

        self.best_pred = 0.0

    def train(self, writer):
        """Run the training loop, writing scalars through ``writer.add_scalar``.

        Iteration numbers are offset by ``args.start_step`` so a resumed run
        continues the global step count for logging, checkpointing and
        validation cadence.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        steps_run = 0
        self.model.train()
        for step, (images, targets, _) in enumerate(self.train_loader):
            # Global step (offset on resume) vs. steps run in this session.
            iteration = step + self.args.start_step
            steps_run = step + 1
            self.lr_scheduler.step()

            for index, item in enumerate(images):
                images[index] = item.to(self.device)
            for index, item in enumerate(targets):
                targets[index] = item.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)
            losses = sum(loss for loss in loss_dict.values())
            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            # Pace is measured over the iterations actually run since
            # start_time; dividing by the global `iteration` would be wrong
            # after a resume and divides by zero when start_step == 0.
            elapsed = time.time() - start_time
            eta_seconds = elapsed / steps_run * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            writer.add_scalar("learning_rate",
                              self.optimizer.param_groups[0]['lr'], iteration)
            writer.add_scalar("Loss/train_loss", losses_reduced.item(),
                              iteration)

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Estimated Time: {}"
                    .format(iteration, max_iters,
                            self.optimizer.param_groups[0]['lr'],
                            losses_reduced.item(), eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                print('saving......')
                save_checkpoint(self.model,
                                self.args,
                                iteration=iteration,
                                is_best=False)
                print('save over!')

            if (iteration % val_per_iters == 0):
                print('evaluating...')
                self.validate(iteration, writer)
                self.model.train()
                print('eval over!')

        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(steps_run, 1)))

    def _log_val_scale(self, tag, metric, losses, iteration, writer):
        """Log loss/accuracy/mIoU/per-class IoU for one scale; return its mIoU.

        ``tag`` is '120' or '60' and is baked into the log line and the
        scalar names.  An empty ``losses`` list yields a NaN mean loss.
        """
        pixAcc, mIoU, IoU = metric.get()
        val_loss = sum(losses) / len(losses) if losses else float('nan')
        logger.info(
            "{}  Loss: {:.3f}, Validation mpixAcc: {:.3f}, mIoU: {:.3f}".
            format(tag, val_loss, pixAcc, mIoU))
        writer.add_scalar("Loss/val{}_loss".format(tag), val_loss, iteration)
        writer.add_scalar("Result/val{}_mIou".format(tag), mIoU, iteration)
        writer.add_scalar("Result/val{}_Acc".format(tag), pixAcc, iteration)

        for cls, iou in enumerate(IoU):
            logger.info("class {:d} : {:.3f}".format(cls, iou))
            writer.add_scalar("Class{}/class_{}".format(tag, cls), iou,
                              iteration)
        return mIoU

    def validate(self, iteration, writer):
        """Evaluate both output scales on the val loader and checkpoint.

        The mean of the '120' and '60' mIoU decides whether this checkpoint
        is the best so far.
        """
        is_best = False
        self.metric_120.reset()
        self.metric_60.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()

        losses_120, losses_60 = [], []
        for image, target, _ in self.val60_loader:
            for index, item in enumerate(image):
                image[index] = item.to(self.device)
            for index, item in enumerate(target):
                target[index] = item.to(self.device)

            with torch.no_grad():
                outputs = model(image)

            self.metric_120.update(outputs[0][0], target[0])
            self.metric_60.update(outputs[0][1], target[1])

            loss_dict = self.criterion(outputs, target)
            # reduce_loss_dict takes the whole loss dict (as in train());
            # pick the per-scale entries from the reduced result.
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_120.append(float(loss_dict_reduced['loss_120']))
            losses_60.append(float(loss_dict_reduced['loss_60']))

        val_mIou_120 = self._log_val_scale('120', self.metric_120,
                                           losses_120, iteration, writer)
        val_mIou_60 = self._log_val_scale('60', self.metric_60, losses_60,
                                          iteration, writer)

        new_pred = (val_mIou_60 + val_mIou_120) / 2.0
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Only rank 0 writes checkpoints, matching the save in train().
        if get_rank() == 0:
            save_checkpoint(self.model, self.args, iteration, is_best)
        synchronize()
Example #3
0
class Evaluator(object):
    """Evaluate a pretrained model on the 'eyes' val split and dump visualizations.

    For every sample it logs cumulative pixAcc / mIoU and writes a PNG that
    stacks (top to bottom) the input image, the color-coded prediction and a
    50/50 overlay of the two.
    """

    # Must match the Normalize() transform in __init__; used to undo it.
    MEAN = (.485, .456, .406)
    STD = (.229, .224, .225)
    # Color per predicted label, written as-is (cv2.imwrite treats it as BGR).
    COLORS = ((255, 0, 0), (0, 255, 0), (0, 0, 255))

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(list(self.MEAN), list(self.STD)),
        ])

        # dataset and dataloader
        # NOTE(review): dataset name is hard-coded to 'eyes' while the model
        # below uses args.dataset — confirm this is intended.
        val_dataset = get_segmentation_dataset('eyes',
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # The model is never wrapped in DistributedDataParallel here, so
            # only unwrap `.module` when it actually exists.
            self.model = getattr(self.model, 'module', self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    @staticmethod
    def _compose_visualization(image, predict, mean=MEAN, std=STD,
                               colors=COLORS):
        """Stack input, color mask and a 50/50 overlay vertically.

        Args:
            image: normalized HxWx3 array in RGB channel order.
            predict: HxW integer label map indexing into ``colors``.
            mean, std: per-channel statistics used by the Normalize transform.
            colors: one 3-tuple per label.

        Returns:
            (3H, W, 3) uint8 array, image channels in BGR order for cv2.
        """
        # Undo Normalize (x * std + mean -> [0, 1]) and scale to 8-bit; this
        # must happen in RGB order, before the channel flip.
        rgb = (image * np.asarray(std) + np.asarray(mean)) * 255.0
        inp = rgb[..., ::-1].astype(np.float64)  # RGB -> BGR
        msk = np.asarray(colors, dtype=np.float64)[predict]
        overlay = 0.5 * inp + 0.5 * msk
        stacked = np.concatenate((inp, msk, overlay), axis=0)
        # uint8 is what cv2.imwrite expects for PNG; clip guards rounding.
        return np.clip(np.round(stacked), 0, 255).astype(np.uint8)

    def eval(self):
        """Run the val loop, logging running metrics and saving one PNG per sample."""
        self.metric.reset()
        self.model.eval()
        model = self.model
        if self.args.distributed:
            model = getattr(model, 'module', model)
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            # Batch size is 1 (images_per_batch=1), so squeeze(0) drops it.
            predict = torch.argmax(outputs[0], 1).cpu().numpy().squeeze(0)
            image_hwc = image.cpu().numpy().squeeze(0).transpose((1, 2, 0))
            res = self._compose_visualization(image_hwc, predict)
            cv2.imwrite(
                f'/root/mitya/Lightweight-Segmentation/results/{i}.png',
                res)
        synchronize()
Example #4
0
class Evaluator(object):
    """Evaluate a two-scale ('120' / '60') segmentation model on the val split.

    With ``args.save_pre`` it writes per-sample prediction masks; with
    ``args.combined`` the '_4' image is a 2x2 panel of
    [image | overlay] over [target | prediction].
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        val60_data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            're_size': args.re_size,
        }
        valset = get_segmentation_dataset(args.dataset,
                                          args=args,
                                          split='val',
                                          mode='val_onlyrs',
                                          **val60_data_kwargs)

        val_sampler = make_data_sampler(valset, True, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)
        self.val60_loader = data.DataLoader(dataset=valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model,
                                            dataset=args.dataset,
                                            args=self.args,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        self.model = load_model(args.resume, self.model)

        # evaluation metrics
        self.metric_120 = SegmentationMetric(valset.num_class)
        self.metric_60 = SegmentationMetric(valset.num_class)

        self.best_pred = 0.0

    def _log_scale(self, tag, metric):
        """Log pixel accuracy, mIoU and per-class IoU for one scale ('120'/'60')."""
        pixAcc, mIoU, IoU = metric.get()
        logger.info("{} Validation: mpixAcc: {:.3f}, mIoU: {:.3f}".format(
            tag, pixAcc, mIoU))
        for cls, iou in enumerate(IoU):
            logger.info("class {:d} : {:.3f}".format(cls, iou))

    def evaluate(self):
        """Run one pass over the val loader and log metrics for both scales."""
        self.metric_120.reset()
        self.metric_60.reset()
        model = self.model
        if self.args.distributed:
            # The model is never wrapped in DistributedDataParallel here, so
            # only unwrap `.module` when it actually exists.
            model = getattr(model, 'module', model)
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()

        for image, target, image_name in self.val60_loader:
            for index, item in enumerate(image):
                image[index] = item.to(self.device)
            for index, item in enumerate(target):
                target[index] = item.to(self.device)

            with torch.no_grad():
                outputs = model(image)

            # outputs[0][0] / target[0] are the '120' scale, [1] the '60' one.
            self.metric_120.update(outputs[0][0], target[0])
            self.metric_60.update(outputs[0][1], target[1])

            if self.args.save_pre:
                self.save_pred(image, target, image_name, outputs)

        self._log_scale('120', self.metric_120)
        self._log_scale('60', self.metric_60)
        synchronize()

    @staticmethod
    def _unnormalize(img, mean, std):
        """Invert per-channel Normalize on an HxWxC image and scale to [0, 255]."""
        # Broadcasting over the trailing channel axis replaces manual tiling.
        return (img * np.asarray(std) + np.asarray(mean)) * 255.

    def _save_scale_pred(self, scale, image, target, image_name, outputs,
                         mean, std):
        """Save the prediction (and optional combined panel) for one scale.

        ``scale`` 0 selects the '120' entries and 1 the '60' entries of
        ``image``, ``target``, ``image_name`` and ``outputs[0]``.
        """
        # str() of the batch name list looks like "['name']"; [2:-2] strips it.
        name = str(image_name[scale])[2:-2]
        prefix = os.path.join(self.args.save_pre_path,
                              name + '_' + self.args.model_mode)

        # Batch size is assumed to be 1: squeeze(0) / [0] drop the batch axis.
        predict = torch.argmax(outputs[0][scale], 1).cpu().data.numpy()
        predict = predict.squeeze(0)
        mask = get_color_pallete(predict, self.args.dataset)
        mask = np.asarray(mask.convert('RGB'))
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; consider
        # imageio.imwrite or PIL.Image.save.
        misc.imsave(prefix + '.png', mask)

        if self.args.combined:
            img = image[scale].cpu().data.numpy()[0].transpose(1, 2, 0)
            img = np.clip(self._unnormalize(img, mean, std), 0, 255)

            tgt = target[scale].cpu().data.numpy().squeeze(0)
            tgt = get_color_pallete(tgt, self.args.dataset)
            tgt = np.asarray(tgt.convert('RGB'))

            top = np.concatenate((img, img * 0.5 + mask * 0.5), axis=1)
            bottom = np.concatenate((tgt, mask), axis=1)
            # uint8 keeps colors exact; float input would be min/max rescaled
            # by misc.imsave.
            mask = np.concatenate((top, bottom), axis=0).astype(np.uint8)

        misc.imsave(prefix + '_4.png', mask)

    def save_pred(self, image, target, image_name, outputs):
        """Write prediction images for the '120' scale, then the '60' scale.

        Each scale is saved under its own name from ``image_name`` (index 0
        for '120', index 1 for '60').
        """
        mean = np.array([.485, .456, .406])
        std = np.array([.229, .224, .225])
        for scale in (0, 1):
            self._save_scale_pred(scale, image, target, image_name, outputs,
                                  mean, std)
Example #5
0
class Evaluator(object):
    """Validate and run inference with a 3-class segmentation model.

    ``eval`` iterates a validation split (image and mask directories),
    logging the running pixel accuracy / mIoU and optionally saving an
    image | mask | prediction figure per sample. ``test`` runs the model on
    every file of an image directory and optionally saves an
    image | prediction figure per file.
    """

    # Fallback data locations, used only when ``args`` does not provide
    # ``valid_img_dir`` / ``valid_mask_dir`` / ``test_img_dir``.
    _DEFAULT_VALID_IMG_DIR = '/home/wangjialei/teeth_dataset/new_data_20190621/valid_new/images'
    _DEFAULT_VALID_MASK_DIR = '/home/wangjialei/teeth_dataset/new_data_20190621/valid_new/masks'
    _DEFAULT_TEST_IMG_DIR = '/home/wangjialei/projects/teeth_bad_case/'

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` if the model is wrapped, else ``model``."""
        return getattr(model, 'module', model)

    def __init__(self, args):
        """Build the validation loader, the model and the metric.

        Args:
            args: parsed options. Reads ``device``, ``distributed``,
                ``workers``, ``model``, ``dataset`` and ``aux``; optionally
                ``valid_img_dir`` / ``valid_mask_dir`` to override the
                default validation data location.
        """
        self.args = args
        self.device = torch.device(args.device)
        self.nb_classes = 3

        valid_img_dir = getattr(args, 'valid_img_dir', None) or self._DEFAULT_VALID_IMG_DIR
        valid_mask_dir = getattr(args, 'valid_mask_dir', None) or self._DEFAULT_VALID_MASK_DIR

        # dataset and dataloader. No transform here; eval() passes each batch
        # through data_process() before the forward pass.
        valid_set = SegmentationData(images_dir=valid_img_dir,
                                     masks_dir=valid_mask_dir,
                                     nb_classes=self.nb_classes,
                                     mode='valid',
                                     transform=None)
        valid_sampler = make_data_sampler(valid_set, False, args.distributed)
        valid_batch_sampler = make_batch_data_sampler(valid_sampler,
                                                      images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=valid_set,
                                          batch_sampler=valid_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)

        if args.distributed:
            # Only unwrap when a ``.module`` attribute actually exists; the
            # previous unconditional ``.module`` (here and again in
            # eval/test) could not succeed on both sides.
            self.model = self._unwrap(self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(valid_set.num_class)

    def eval(self):
        """Compute pixAcc / mIoU over the validation set, logging running values.

        When ``args.save_pred`` is set, one image | mask | prediction figure
        per sample is written to ``save_fig_val/<sample index>.png``.
        """
        self.metric.reset()
        self.model.eval()
        model = self._unwrap(self.model) if self.args.distributed else self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))

        save_file = "save_fig_val"
        if self.args.save_pred:
            os.makedirs(save_file, exist_ok=True)

        for i, (image, target) in enumerate(self.val_loader):
            img = data_process(image)

            img = img.to(self.device)
            target = target.to(self.device)
            with torch.no_grad():
                outputs = model(img)
            # NOTE(review): the whole ``outputs`` sequence is passed here while
            # the prediction below uses ``outputs[0]``; confirm this matches
            # what SegmentationMetric.update expects.
            self.metric.update(outputs, target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))
            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                predict = pred.cpu().data.numpy().squeeze(0)

                # Fresh figure per sample, closed after saving. Reusing the
                # implicit global figure stacked another image artist onto
                # each axes on every iteration (unbounded memory growth).
                fig = plt.figure()

                img_show = image[0].numpy().astype('uint8')
                plt.subplot(1, 3, 1)
                plt.title('image')
                plt.imshow(img_show)

                # Drops the leading batch axis; relies on images_per_batch=1.
                mask = target.cpu().data.numpy()
                mask = mask.reshape(mask.shape[1], mask.shape[2])
                mask = mask_to_image(mask)
                plt.subplot(1, 3, 2)
                plt.title('mask')
                plt.imshow(mask)

                predict = mask_to_image(predict)
                plt.subplot(1, 3, 3)
                plt.title('pred')
                plt.imshow(predict)

                plt.savefig(os.path.join(save_file, str(i) + '.png'))
                plt.close(fig)
        synchronize()

    def test(self):
        """Run the model on every file in the test image directory.

        The directory is ``args.test_img_dir`` when provided, otherwise the
        class default. Files are processed in sorted-name order. When
        ``args.save_pred`` is set, one image | prediction figure per file is
        written to ``save_fig_test/<file index>.png``.
        """
        self.model.eval()
        model = self._unwrap(self.model) if self.args.distributed else self.model
        test_img_dir = getattr(self.args, 'test_img_dir', None) or self._DEFAULT_TEST_IMG_DIR

        save_file = "save_fig_test"
        if self.args.save_pred:
            os.makedirs(save_file, exist_ok=True)

        # Sorted so the index used in the output file name is reproducible.
        for idx, img_file in enumerate(sorted(os.listdir(test_img_dir))):
            img_name = os.path.join(test_img_dir, img_file)
            image = Image.open(img_name)
            logger.debug("%s: %s", img_name, type(image))
            img = data_process(image)
            img = img.to(self.device)
            with torch.no_grad():
                outputs = model(img)
            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                predict = pred.cpu().data.numpy().squeeze(0)

                # Fresh figure per file, closed after saving (see eval()).
                fig = plt.figure()

                plt.subplot(1, 2, 1)
                plt.title('image')
                plt.imshow(image)

                predict = mask_to_image(predict)
                plt.subplot(1, 2, 2)
                plt.title('pred')
                plt.imshow(predict)

                plt.savefig(os.path.join(save_file, str(idx) + '.png'))
                plt.close(fig)
Exemple #6
0
class Evaluator(object):
    """Evaluate a segmentation model on the ``val`` split of ``args.dataset``.

    ``eval`` logs the running pixel accuracy / mIoU after every sample and,
    when ``args.save_pred`` is set, writes each colour-mapped prediction as
    a PNG named by sample index.
    """

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` if the model is wrapped, else ``model``."""
        return getattr(model, 'module', model)

    def __init__(self, args):
        """Build the validation loader, the model and the metric.

        Args:
            args: parsed options; reads ``device``, ``dataset``,
                ``distributed``, ``workers``, ``model`` and ``aux``.
        """
        self.args = args
        self.device = torch.device(args.device)

        # image transform: to tensor, then per-channel normalisation with the
        # standard ImageNet mean / std values.
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader: one image per batch, no shuffling.
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            aux=args.aux,
                                            pretrained=True,
                                            pretrained_base=False)
        if args.distributed:
            # Only unwrap when a ``.module`` attribute actually exists; the
            # previous unconditional ``.module`` (here and again in eval)
            # could not succeed on both sides.
            self.model = self._unwrap(self.model)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run validation, logging pixAcc / mIoU after every sample.

        When ``args.save_pred`` is set, the colour-mapped prediction of
        sample ``i`` is saved as ``<outdir>/<i>.png``. ``outdir`` is
        ``args.outdir`` when present and non-empty, else ``'save_pred_val'``.
        """
        self.metric.reset()
        self.model.eval()
        model = self._unwrap(self.model) if self.args.distributed else self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))

        outdir = getattr(self.args, 'outdir', None) or 'save_pred_val'
        if self.args.save_pred:
            os.makedirs(outdir, exist_ok=True)

        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc * 100, mIoU * 100))

            if self.args.save_pred:
                pred = torch.argmax(outputs[0], 1)
                predict = pred.cpu().data.numpy().squeeze(0)
                mask = get_color_pallete(predict, self.args.dataset)
                # Previously computed and discarded, so save_pred had no
                # effect. The loader yields no file names here, so name
                # the output by sample index.
                mask.save(os.path.join(outdir, str(i) + '.png'))
        synchronize()