Example #1
class Predictor:
    def __init__(self, config,  checkpoint_path='./snapshots/checkpoint_best.pth.tar'):
        self.config = config
        self.checkpoint_path = checkpoint_path

        self.categories_dict = {"background": 0, "short_sleeve_top": 1, "long_sleeve_top": 2, "short_sleeve_outwear": 3,
                "long_sleeve_outwear": 4, "vest": 5, "sling": 6, "shorts": 7, "trousers": 8,
                "skirt": 9,  "short_sleeve_dress": 10, "long_sleeve_dress": 11,
                "vest_dress": 12, "sling_dress": 13}

        self.categories_dict_rev = {v: k for k, v in self.categories_dict.items()}
        
        self.model = self.load_model()
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)

        self.num_classes = self.config['network']['num_classes']
        self.evaluator = Evaluator(self.num_classes)
        self.criterion = SegmentationLosses(weight=None, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])


    def load_model(self):
        model = DeepLab(num_classes=self.config['network']['num_classes'], backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'], sync_bn=False, freeze_bn=True)


        if self.config['network']['use_cuda']:
            checkpoint = torch.load(self.checkpoint_path)
        else:
            checkpoint = torch.load(self.checkpoint_path, map_location='cpu')  # remap all GPU tensors to CPU

        model = torch.nn.DataParallel(model)  # checkpoint keys carry DataParallel's 'module.' prefix

        model.load_state_dict(checkpoint['state_dict'])

        return model

    def inference_on_test_set(self):
        print("inference on test set")

        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')  # iterate the test split, to match the method name
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print("Accuracy:{}, Accuracy per class:{}, mean IoU:{}, frequency weighted IoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)



    def segment_image(self, filename):

        img = Image.open(filename).convert('RGB')

        sample = {'image': img, 'label': img}  # dummy label, only needed to satisfy the preprocess signature

        sample = DeepFashionSegmentation.preprocess(sample, crop_size=513)
        image, _ = sample['image'], sample['label']
        image = image.unsqueeze(0)

        with torch.no_grad():
            prediction = self.model(image)

        image = image.squeeze(0).numpy()
        image = denormalize_image(np.transpose(image, (1, 2, 0)))
        image *= 255.

        prediction = prediction.squeeze(0).cpu().numpy()

        prediction = np.argmax(prediction, axis=0)

        return image, prediction
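Every example in this collection leans on an Evaluator that accumulates a confusion matrix and derives pixel accuracy, mIoU, and frequency-weighted IoU from it, but none of the snippets show that class. The following is a minimal, self-contained sketch of how such metrics are commonly computed; the actual Evaluator these snippets import may differ in detail (for example, in how classes absent from a batch are averaged).

import numpy as np

class MiniEvaluator:
    """Sketch of a confusion-matrix evaluator like the one used above."""
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class), dtype=np.int64)

    def add_batch(self, gt, pred):
        # gt and pred are integer label arrays of identical shape
        mask = (gt >= 0) & (gt < self.num_class)
        idx = self.num_class * gt[mask].astype(int) + pred[mask].astype(int)
        counts = np.bincount(idx, minlength=self.num_class ** 2)
        self.confusion_matrix += counts.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Pixel_Accuracy_Class(self):
        per_class = np.diag(self.confusion_matrix) / np.maximum(self.confusion_matrix.sum(axis=1), 1)
        return per_class.mean()

    def Mean_Intersection_over_Union(self):
        inter = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0) - inter).astype(np.float64)
        iou = np.where(union > 0, inter / np.maximum(union, 1), np.nan)
        return np.nanmean(iou)

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = self.confusion_matrix.sum(axis=1) / self.confusion_matrix.sum()
        inter = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0) - inter).astype(np.float64)
        iou = np.where(union > 0, inter / np.maximum(union, 1), 0.0)
        return (freq * iou).sum()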
Example #2
class Trainer(object):

    def __init__(self, args, modelConfig, inputH5Path):

        # Get training parameters
        hyperpars = args["hyperparameters"]
        archpars = args["architecture"]

        # Get model config
        structList = modelConfig["structList"]
        nclass = len(structList) + 1  # + 1 for background class

        args["nclass"] = nclass
        args["inputH5Path"] = inputH5Path
        if torch.cuda.device_count() and torch.cuda.is_available():
            print('Using GPU...')
            args["cuda"] = True
            deviceCount = torch.cuda.device_count()
            print('GPU device count: ', deviceCount)
        else:
            print('using CPU...')
            args["cuda"] = False

        # Use default args where missing
        defPars = {'fineTune': False, 'resumeFromCheckpoint': None, 'validate': True,
                   'evalInterval': 1}
        defHyperpars = {'startEpoch': 0}
        for key, val in defPars.items():
            args.setdefault(key, val)
        for key, val in defHyperpars.items():
            hyperpars.setdefault(key, val)

        args["hyperparameters"] = hyperpars
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloaders
        kwargs = {'num_workers': 1, 'pin_memory': True}
        train_set = customData(self.args, split='Train')
        self.train_loader = DataLoader(train_set, batch_size=hyperpars["batchSize"], shuffle=True, drop_last=True,
                                       **kwargs)
        val_set = customData(self.args, split='Val')
        self.val_loader = DataLoader(val_set, batch_size=hyperpars["batchSize"], shuffle=False, drop_last=False,
                                     **kwargs)
        test_set = customData(self.args, split='Test')
        self.test_loader = DataLoader(test_set, batch_size=hyperpars["batchSize"], shuffle=False, drop_last=False,
                                      **kwargs)

        # Define network
        model = DeepLab(num_classes=args["nclass"],
                        backbone='resnet',
                        output_stride=archpars["outStride"],
                        sync_bn=archpars["sync_bn"],
                        freeze_bn=archpars["freeze_bn"],
                        model_path=args["modelSavePath"])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': hyperpars["lr"]},
                        {'params': model.get_10x_lr_params(), 'lr': hyperpars["lr"] * 10}]

        # Define Optimizer
        optimizer_type = args["optimizer"]
        if optimizer_type.lower() == 'sgd':
            optimizer = torch.optim.SGD(train_params, momentum=hyperpars["momentum"],
                                        weight_decay=hyperpars["weightDecay"], nesterov=hyperpars["nesterov"])
        elif optimizer_type.lower() == 'adam':
            optimizer = torch.optim.Adam(train_params, lr=hyperpars["lr"],
                                         betas=(0.9, 0.999), eps=1e-08,
                                         weight_decay=hyperpars["weightDecay"])
        else:
            raise ValueError("Unsupported optimizer: '{}'".format(optimizer_type))

        # Initialize weights
        print('Initializing weights...')
        initWeights = args["initWeights"]
        if initWeights["method"] == "classBalanced":
            # Use class balanced weights
            print('Using class-balanced weights.')
            class_weights_path = os.path.join(inputH5Path, 'classWeights.npy')
            if os.path.isfile(class_weights_path):
                print('reading weights from ' + class_weights_path)
                weight = np.load(class_weights_path)
            else:
                weight = calculate_weights_labels(inputH5Path, self.train_loader, args["nclass"])
                np.save(class_weights_path, weight)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        # Define loss function
        self.criterion = SegmentationLosses(weight=weight, cuda=args["cuda"]).build_loss(mode=args["lossType"])
        self.model, self.optimizer = model, optimizer

        # Define evaluator
        self.evaluator = Evaluator(args["nclass"])

        # Define lr scheduler
        self.scheduler = LR_Scheduler(hyperpars["lrScheduler"], hyperpars["lr"],
                                      hyperpars["maxEpochs"], len(self.train_loader))

        # Use GPU(s) if available
        if args["cuda"]:
            self.model = torch.nn.DataParallel(self.model, list(range(deviceCount)))
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resume from previous checkpoint
        self.best_pred = 0.0
        if args["resumeFromCheckpoint"] is not None:
            resume_path = args["resumeFromCheckpoint"]
            if not os.path.isfile(resume_path):
                raise RuntimeError("=> no checkpoint found at '{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            hyperpars["startEpoch"] = checkpoint['epoch']
            if args["cuda"]:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            # When fine-tuning, start from a fresh optimizer state
            if not args["fineTune"]:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume_path, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args["fineTune"]:
            hyperpars["startEpoch"] = 0

    def training(self, epoch):
        args = self.args
        hyperpars = args["hyperparameters"]
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if args["cuda"]:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show inference results roughly ten times per epoch
            if i % max(1, num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, args, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * hyperpars["batchSize"] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if not args["validate"]:
            # save checkpoint every epoch
            is_best = False
            # Handle both DataParallel-wrapped (CUDA) and plain (CPU) models
            model_state = (self.model.module.state_dict()
                           if hasattr(self.model, 'module') else self.model.state_dict())
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_state,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args["cuda"]:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args["hyperparameters"]["batchSize"] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            print('Best model yet!')
            # Handle DataParallel-wrapped and plain models transparently
            try:
                state_dict = self.model.module.state_dict()
            except AttributeError:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

        if (epoch + 1) == self.args["hyperparameters"]["maxEpochs"]:
            is_best = False
            try:
                state_dict = self.model.module.state_dict()
            except AttributeError:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                'filename': 'last_checkpoint.pth.tar',
            }, is_best)
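A plausible driver for this Trainer, assuming the args/modelConfig/inputH5Path dictionary layout the constructor above reads (the entry point itself is hypothetical; the key names are taken from that code):

# Hypothetical entry point for the Trainer above.
trainer = Trainer(args, modelConfig, inputH5Path)
hyperpars = trainer.args["hyperparameters"]
for epoch in range(hyperpars["startEpoch"], hyperpars["maxEpochs"]):
    trainer.training(epoch)
    if trainer.args["validate"] and (epoch + 1) % trainer.args["evalInterval"] == 0:
        trainer.validation(epoch)
trainer.writer.close()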
Example #3
class Validator(object):
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.time_train = []
        self.args.evaluate = True
        self.args.merge = True
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        _, self.val_loader, _, self.num_class = make_data_loader(
            args, **kwargs)
        print('num_classes: ' + str(self.num_class))
        self.resize = args.crop_size if args.crop_size else [512, 1024]
        self.evaluator = Evaluator(self.num_class, self.logger)
        self.model = EDCNet(self.args.rgb_dim,
                            args.event_dim,
                            num_classes=self.num_class,
                            use_bn=True)
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.to(self.args.device)
            cudnn.benchmark = True
        print('Model loaded successfully!')
        assert os.path.exists(
            args.weight_path), 'weight-path:{} doesn\'t exist!'.format(
                args.weight_path)
        self.new_state_dict = torch.load(args.weight_path,
                                         map_location='cuda:0')
        self.model = load_my_state_dict(self.model.module,
                                        self.new_state_dict['state_dict'])

    def validate(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        for i, (sample, gt_path) in enumerate(tbar):
            target = sample['label']
            image = sample['image']
            event = sample['event']
            if self.args.cuda:
                target = target.to(self.args.device)
                image = image.to(self.args.device)
                event = event.to(self.args.device)
            start_time = time.time()
            with torch.no_grad():
                output, output_event = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)
            if self.args.cuda:
                torch.cuda.synchronize()
            if i != 0:
                fwt = time.time() - start_time
                self.time_train.append(fwt)
                print(
                    "Forward time per img (bath size=%d): %.3f (Mean: %.3f)" %
                    (self.args.val_batch_size, fwt / self.args.val_batch_size,
                     sum(self.time_train) / len(self.time_train) /
                     self.args.val_batch_size))
            time.sleep(0.1)

            pre_colors = Colorize()(torch.max(output,
                                              1)[1].detach().cpu().byte())
            pre_colors_gt = Colorize()(torch.ByteTensor(target))
            checkname = self.args.weight_path.split('/')[-2]
            prediction_save_dir = os.path.join(self.args.label_save_path,
                                               checkname)
            if self.args.label_save:
                for j in range(pre_colors.shape[0]):
                    label_name = os.path.join(*[
                        prediction_save_dir, gt_path[j].split('gtFine/val/')
                        [1].replace('/', '_')
                    ])
                    if 'dada' in self.args.dataset:
                        label_name = label_name.replace('.jpg', '.png')
                    os.makedirs(os.path.dirname(label_name), exist_ok=True)
                    if 'dada' in self.args.dataset:
                        leftImg8bit_path = gt_path[j].replace(
                            '_labelTrainIds.png', '.jpg')
                    elif 'cityscape' in self.args.dataset:
                        leftImg8bit_path = gt_path[j].replace(
                            '_gtFine_labelTrainIds.png', '_leftImg8bit.png')
                    leftImg8bit_path = leftImg8bit_path.replace(
                        '/gtFine/', '/leftImg8bit/')
                    pre_color_image = ToPILImage()(pre_colors[j])
                    pre_color_gt_image = ToPILImage()(pre_colors_gt[j])  # new name, so the batch of GT colors isn't overwritten mid-loop
                    img_ = Image.open(leftImg8bit_path)
                    img_ = img_.crop(
                        (280, 32, 1304, 544))  # [162, 0, 1422, 600]
                    img_ = img_.resize((self.resize[1], self.resize[0]),
                                       Image.BILINEAR)
                    event_ = ToPILImage()(event[j].cpu())
                    pre_event_ = ToPILImage()(torch.sigmoid(
                        output_event[j]).cpu())  # blur

                    if self.args.event_dim:
                        event_path = leftImg8bit_path.replace(
                            '/leftImg8bit/', '/event_image/')
                        event_path = event_path.replace(
                            '.jpg', '_event_image.png')
                        image_stack(img_, pre_color_image, pre_color_gt_image,
                                    label_name, event_, pre_event_)
                    else:
                        image_stack(Image.open(leftImg8bit_path),
                                    pre_color_image, pre_color_gt_image,
                                    label_name)
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.logger.info('Validation:')
        self.logger.info('[Epoch: %d, numImages: %5d]' %
                         (0, i * self.args.val_batch_size + target.data.shape[0]))
        self.logger.info("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
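This Validator calls a load_my_state_dict helper that is not shown. A common implementation copies only the entries whose names (modulo DataParallel's 'module.' prefix) and shapes match the target model; a sketch under that assumption:

def load_my_state_dict(model, state_dict):
    # Sketch of a tolerant partial loader; the repo's actual helper may differ.
    own_state = model.state_dict()
    for name, param in state_dict.items():
        name = name.replace('module.', '')  # tolerate DataParallel prefixes
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)
    return model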
Example #4
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            raise ValueError("Unsupported norm '{}'; expected gn, bn, or abn".format(self.args.norm))

        # Define network
        # Todo: add option for other networks
        model = LaneDeepLab(args=self.args,
                            num_classes=self.nclass,
                            freeze_bn=args.freeze_bn)
        """
        model.cuda()
        summary(model, input_size=(3, 720, 1280))
        exit()
        """

        train_params = [
            {
                "params": model.get_1x_lr_params(),
                "lr": args.lr
            },
            {
                "params": model.get_10x_lr_params(),
                "lr": args.lr * 10
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy")
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            if args.ft:
                args.start_epoch = 0
            else:
                args.start_epoch = checkpoint["epoch"]

            if args.cuda:
                # self.model.module.load_state_dict(checkpoint['state_dict'])
                pretrained_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                # self.model.load_state_dict(checkpoint['state_dict'])
                pretrained_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        elif args.decoder is not None:
            if not os.path.isfile(args.decoder):
                raise RuntimeError(
                    "=> no checkpoint for decoder found at '{}'".format(
                        args.decoder))
            checkpoint = torch.load(args.decoder)
            # Loading only the decoder always implies fine-tuning from epoch 0
            args.start_epoch = 0
            if args.cuda:
                decoder_dict = checkpoint["state_dict"]
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in decoder_dict.items():
                    if not "aspp" in k:
                        continue
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                raise NotImplementedError("Please USE CUDA!!!")

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            lanes = sample["lanes"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                lanes = lanes.cuda().unsqueeze(1)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, (target, lanes))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description("Train loss: %.3f" % (train_loss / (i + 1)))

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Loss: %.3f" % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch, inference=False):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample["image"], sample["label"]
            lanes = sample["lanes"]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
                lanes = lanes.cuda().unsqueeze(1)
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, (target, lanes))
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output[:, :-1, :, :].data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: %.3f" % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )
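Several of these trainers step an LR_Scheduler once per iteration as scheduler(optimizer, i, epoch, best_pred). That class is not shown; below is a minimal sketch of the 'poly' policy it typically implements, scaling the second parameter group by 10x to match the train_params above (the class name and the 0.9 power are assumptions):

class PolyLR:
    # Sketch of a 'poly' schedule: lr = base_lr * (1 - t / T) ** power.
    def __init__(self, base_lr, num_epochs, iters_per_epoch, power=0.9):
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.max_iter = num_epochs * iters_per_epoch
        self.power = power

    def __call__(self, optimizer, i, epoch, best_pred=None):
        t = epoch * self.iters_per_epoch + i
        lr = self.base_lr * (1 - t / self.max_iter) ** self.power
        optimizer.param_groups[0]['lr'] = lr      # backbone group
        for group in optimizer.param_groups[1:]:
            group['lr'] = lr * 10                 # decoder/head group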
Example #5
class ArchitectureSearcher(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        # Save the experiment configuration (dataset, epochs, ...) to a file
        self.saver.save_experiment_config()

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # TODO: understand what this is for
        weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        model = AutoDeeplab(self.nclass, 10, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)

        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        # TODO: understand what this is for
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)

        self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.clean_module:
                # Strip the 'module.' prefix that DataParallel adds to keys
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                copy_state_dict(self.model.state_dict(),
                                checkpoint['state_dict'])

            if not args.ft:
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))


        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            # Reset the gradients of all weight parameters
            self.optimizer.zero_grad()
            # Predict masks for an image batch drawn from dataset A
            output = self.model(image)
            # Segmentation loss (loss A) between prediction and target
            loss = self.criterion(output, target)
            # Backprop: gradients are stored in each parameter's .grad
            loss.backward()
            # Update the network weights bound to this optimizer
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search = image_search.cuda()
                    target_search = target_search.cuda()

                # Reset the gradients of the architecture (alpha/beta) parameters
                self.architect_optimizer.zero_grad()
                # Predict masks for an image batch drawn from dataset B
                output_search = self.model(image_search)
                # Segmentation loss (loss B) between prediction and target
                arch_loss = self.criterion(output_search, target_search)
                # Backprop: gradients land in the alpha/beta parameters' .grad
                arch_loss.backward()
                # Update the architecture parameters bound to this optimizer
                self.architect_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))


        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred

            # state_dict() lets us save, update, and restore PyTorch model and optimizer state
            state_dict = self.model.state_dict()
            # Save the checkpoint to disk; this Saver method creates model_best.pth
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
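copy_state_dict is used above but not defined in this snippet. For flat, tensor-valued state dicts (a model's), it can be sketched as an in-place copy of matching keys; an optimizer state_dict is nested, so the real helper presumably handles more cases:

def copy_state_dict(target_state, source_state):
    # Sketch for flat, tensor-valued state dicts; assumes matching keys and shapes.
    for name, param in source_state.items():
        if name in target_state and target_state[name].shape == param.shape:
            target_state[name].copy_(param)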
Example #6
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Dataloader
        if args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2,4,23,3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # Adam takes no momentum argument (it uses betas), so none is passed here
            optimizer = torch.optim.Adam(fpn.parameters(), lr=self.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=self.lr, momentum=0, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        # Define Criterion
        if args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()


        # Resuming checkpoint
        self.best_pred = 0.0
        self.lr_stage = [68, 93]
        self.lr_stage_ind = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        # Decay the LR once each time a stage boundary (epoch 68, then 93) is reached
        if self.lr_stage_ind < len(self.lr_stage) and epoch >= self.lr_stage[self.lr_stage_ind]:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_stage_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()

            outputs = self.model(image)
            loss = self.criterion(outputs, target.long())
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, len(self.train_loader), loss.data, self.lr))
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        for i, batch in enumerate(self.val_loader):
            if self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            print('Test Loss:%.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
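The adjust_learning_rate helper used in training() above is not shown. Given that the caller also does self.lr *= self.args.lr_decay_gamma, it plausibly applies the same decay to every param group:

def adjust_learning_rate(optimizer, decay_gamma):
    # Sketch: stage-wise decay of each param group's learning rate.
    for group in optimizer.param_groups:
        group['lr'] *= decay_gamma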
Example #7
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            size = 512
            data_root = os.path.join(os.getcwd(), "data", "CamVid")  # portable path join instead of hard-coded backslashes
            train_file = os.path.join(data_root, "train.csv")
            val_file = os.path.join(data_root, "val.csv")
            print('=> loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file,
                                     phase='val',
                                     flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(
                args, **kwargs)
        elif args.dataset == 'NYUDv2':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.num_class = make_data_loader(
                args, **kwargs)

        # Define network
        if args.net == 'resnet101':
            blocks = [2, 4, 23, 3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            self.lr = self.lr * 0.1
            # Adam takes no momentum argument (it uses betas), so none is passed here
            optimizer = torch.optim.Adam(fpn.parameters(),
                                         lr=self.lr,
                                         weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(),
                                        lr=args.lr,
                                        momentum=0,
                                        weight_decay=args.weight_decay)

        # Define Criterion
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        elif args.dataset == 'Cityscapes':
            weight = None
            self.criterion = SegmentationLosses(
                weight=weight, cuda=args.cuda).build_loss(mode='ce')
        elif args.dataset == 'NYUDv2':
            weight = None
            self.criterion = SegmentationLosses(
                weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset,
                                      args.checkname)
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')))
            run_id = int(runs[-1].split('_')[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir,
                                          'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir, 'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(load_name))
            checkpoint = torch.load(load_name)
            args.start_epoch = checkpoint['epoch']
            # The same load works with or without CUDA here
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(
                load_name, checkpoint['epoch']))

        self.lr_stage = [68, 93]
        self.lr_stage_ind = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        # Decay the LR once each time a stage boundary (epoch 68, then 93) is reached
        if self.lr_stage_ind < len(self.lr_stage) and epoch >= self.lr_stage[self.lr_stage_ind]:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_stage_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            elif self.args.dataset == 'NYUDv2':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()

            outputs = self.model(image)
            loss = self.criterion(outputs, target.long())
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(
                    epoch, iteration, len(self.train_loader), loss.item(),
                    self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   iteration + num_img_tr * epoch)


        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # Handle both DataParallel-wrapped (mGPUs) and plain models
            model_state = (self.model.module.state_dict()
                           if hasattr(self.model, 'module') else self.model.state_dict())
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, batch in enumerate(tbar):
            if self.args.dataset == 'CamVid':
                image, target = batch['X'], batch['l']
            elif self.args.dataset == 'Cityscapes':
                image, target = batch['image'], batch['label']
            elif self.args.dataset == 'NYUDv2':
                image, target = batch['image'], batch['label']
            else:
                raise NotImplementedError
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f ' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
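The Saver.save_checkpoint method these trainers call is not shown. Judging from its call sites (a state dict, an is_best flag, and occasionally a 'filename' key inside the state), a minimal sketch could look like this; the best-model filename is an assumption:

import os
import shutil
import torch

def save_checkpoint(state, is_best, experiment_dir, filename='checkpoint.pth.tar'):
    # Honor an optional per-call 'filename' key, as with 'last_checkpoint.pth.tar' above.
    filename = state.pop('filename', filename)
    path = os.path.join(experiment_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(experiment_dir, 'model_best.pth.tar'))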
Example #8
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.nclass = 16
        # Define network
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4,
                                           n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192,
                                            n_classes=self.nclass)

        train_params = [{'params': self.combine_net_model.get_params()}]
        # Define Optimizer
        self.optimizer = torch.optim.Adam(train_params,
                                          self.args.learn_rate,
                                          weight_decay=args.weight_decay,
                                          amsgrad=True)

        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Load Unet checkpoint
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no Unet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded Unet checkpoint '{}'".format(
            args.unet_checkpoint_file))

        # Load UNetNested checkpoint
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError(
                "=> no UNetNested checkpoint found at '{}'".format(
                    args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Resuming combineNet checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no combineNet checkpoint found at '{}'".format(
                        args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # combine_net_model is never wrapped in DataParallel here, so load directly
            self.combine_net_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded combineNet checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        print('[Epoch: %d, previous best = %.4f]' % (epoch, self.best_pred))
        train_loss = 0.0
        self.combine_net_model.train()
        self.evaluator.reset()
        num_img_tr = len(train_files)
        tbar = tqdm(train_files, desc='\r')

        for i, filename in enumerate(tbar):
            image = Image.open(os.path.join(train_dir, filename))
            label = Image.open(
                os.path.join(
                    train_label_dir,
                    os.path.basename(filename)[:-4] + '_labelTrainIds.png'))
            label = np.array(label).astype(np.float32)
            label = label.reshape((1, 400, 400))
            label = torch.from_numpy(label).float()
            label = label.cuda()

            # UNet multi-scale prediction
            unet_pred = self.unet_multi_scale_predict(image)

            # UNetNested multi-scale prediction
            unetnested_pred = self.unetnested_multi_scale_predict(image)

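            # each multi-scale predictor stacks 6 augmented outputs (original,
            # 3 rotations, 2 flips) x 16 classes = 96 channels; concatenating
            # both gives the 192 channels CombineNet(in_channels=192) expects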
            net_input = torch.cat([unet_pred, unetnested_pred], 1)

            self.optimizer.zero_grad()
            output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # accumulate training metrics for every batch (inside the loop,
            # mirroring validation() below)
            pred = output.data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('train/mIoU', mIoU, epoch)
        self.writer.add_scalar('train/Acc', Acc, epoch)
        self.writer.add_scalar('train/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('train/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        print('Training metrics:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % train_loss)
        print('---------------------------------')

    def validation(self, epoch):
        test_loss = 0.0
        self.combine_net_model.eval()
        self.evaluator.reset()
        tbar = tqdm(val_files, desc='\r')
        num_img_val = len(val_files)

        for i, filename in enumerate(tbar):
            image = Image.open(os.path.join(val_dir, filename))
            label = Image.open(
                os.path.join(
                    val_label_dir,
                    os.path.basename(filename)[:-4] + '_labelTrainIds.png'))
            label = np.array(label).astype(np.float32)
            label = label.reshape((1, 400, 400))
            label = torch.from_numpy(label).float()
            label = label.cuda()

            # UNet multi-scale prediction
            unet_pred = self.unet_multi_scale_predict(image)

            # UNetNested multi-scale prediction
            unetnested_pred = self.unetnested_multi_scale_predict(image)

            net_input = torch.cat([unet_pred, unetnested_pred], 1)

            with torch.no_grad():
                output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1)))
            self.writer.add_scalar('val/total_loss_iter', loss.item(),
                                   i + num_img_val * epoch)
            pred = output.data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation metrics:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        print('====================================')

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.combine_net_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def unet_multi_scale_predict(self, image_ori: Image.Image):
        self.unet_model.eval()

        # predict on the original image
        sample_ori = image_ori.copy()
        output_ori = self.unet_predict(sample_ori)

        # predict on three rotated copies, then rotate the predictions back
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unet_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            # rotate back around the image center (400x400 inputs)
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(pred).unsqueeze(0).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # predict on the vertically flipped image, then flip the prediction back
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(pred).unsqueeze(0).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # predict on the horizontally flipped image, then flip the prediction back
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(pred).unsqueeze(0).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    def unet_predict(self, img: Image.Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unet_model(img)
        return output

    def unetnested_predict(self, img: Image.Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unetNested_model(img)
        return output

    def unetnested_multi_scale_predict(self, image_ori: Image.Image):
        self.unetNested_model.eval()

        # predict on the original image
        sample_ori = image_ori.copy()
        output_ori = self.unetnested_predict(sample_ori)

        # predict on three rotated copies, then rotate the predictions back
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unetnested_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            # rotate back around the image center (400x400 inputs)
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(pred).unsqueeze(0).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # predict on the vertically flipped image, then flip the prediction back
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(pred).unsqueeze(0).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # predict on the horizontally flipped image, then flip the prediction back
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(pred).unsqueeze(0).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    @staticmethod
    def transform_test(img):
        # Normalize
        mean = (0.544650, 0.352033, 0.384602, 0.352311)
        std = (0.249456, 0.241652, 0.228824, 0.227583)
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # ToTensor: HWC -> CHW, then add a batch dimension
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img).float().unsqueeze(0)
        return img
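A minimal driver sketch for the ensemble Trainer above, assuming an args
namespace that carries the attributes __init__ reads (learn_rate,
weight_decay, cuda, loss_type, unet_checkpoint_file,
unetNested_checkpoint_file, resume, ft, epochs, ...) and that train_files,
val_files and the directory globals are defined at module level, as the
methods expect; the flag names here are assumptions, not a confirmed CLI:

import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--start_epoch', type=int, default=0)
    # ... the remaining Trainer arguments listed above would be added here
    args = parser.parse_args()

    trainer = Trainer(args)
    # train, then validate (validation also checkpoints on a new best mIoU)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        trainer.validation(epoch)

if __name__ == '__main__':
    main()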
Example #9
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)



        # Define network
        #================================== network ==============================================#
        network_G = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        
        network_D = networks.define_D(4, 64, netD='basic', n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=self.args.gpu_ids)
        #=========================================================================================#



        train_params = [{'params': network_G.get_1x_lr_params(), 'lr': args.lr},
                        {'params': network_G.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        #================================== network ==============================================#
        optimizer_G = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer_D = torch.optim.Adam(network_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        #=========================================================================================#
        
        
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        
        
        
        #======================================== criterion======================================#
        self.criterionGAN = networks.GANLoss('vanilla').to(args.gpu_ids[0])  ### set device manually
        self.criterionL1 = torch.nn.L1Loss()
        # segmentation loss used by validation() below; reuses the class
        # weights computed above
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        
        self.network_G, self.network_D = network_G, network_D
        self.optimizer_G, self.optimizer_D = optimizer_G, optimizer_D
        #========================================================================================#
        
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.network_G = torch.nn.DataParallel(self.network_G, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.network_G)
            self.network_G = self.network_G.cuda()


        #====================== no resume ===================================================================#
        # Resuming checkpoint
        self.best_pred = 0.0
#        if args.resume is not None:
#            if not os.path.isfile(args.resume):
#                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
#            checkpoint = torch.load(args.resume)
#            args.start_epoch = checkpoint['epoch']
#            if args.cuda:
#                self.network_G.module.load_state_dict(checkpoint['state_dict'])
#            else:
#                self.network_G.load_state_dict(checkpoint['state_dict'])
#            if not args.ft:
#                self.optimizer.load_state_dict(checkpoint['optimizer'])
#            self.best_pred = checkpoint['best_pred']
#            print("=> loaded checkpoint '{}' (epoch {})"
#                  .format(args.resume, checkpoint['epoch']))
        #=======================================================================================================#
        
        

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        
        #======================== train mode to set batch normalization =======================================#
        self.network_G.train()
        self.network_D.train()
        #======================================================================================================#
        
        
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer_G, i, epoch, self.best_pred)
            
            #================================= GAN training process (pix2pix) ============================================#
            
            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #
            image = image.clone().detach().requires_grad_(True)
            output = self.network_G(image)
            
            
            
            # solve tensor size ==========================================================#
            # use the actual batch size; the last batch may be smaller than args.batch_size
            batch_size = image.size(0)
            real_B = target.view(batch_size, 1, 513, 513)
            fake_B = torch.argmax(output, dim=1).view(batch_size, 1, 513, 513).float()
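            # NOTE: fake_B comes from argmax, which is non-differentiable, so
            # no gradient reaches the generator through it; D still trains
            # pix2pix-style on concatenated (image, mask) pairs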
            fake_AB = torch.cat((image, fake_B), 1)
            real_AB = torch.cat((image, real_B), 1)
            ### debug###########
#            print('image size')
#            print(image.size())
#            print('output size')
#            print(output.size())
#            print('target size')
#            print(target.size())
#            print('fake_AB size')
#            print(fake_AB.size())            
            # ============================================================================#
            
            self.set_requires_grad(self.network_G, False)
            self.set_requires_grad(self.network_D, True)
            self.optimizer_D.zero_grad()
            
            # fake concatenate
            pred_fake = self.network_D(fake_AB.detach())
            loss_D_fake = self.criterionGAN(pred_fake, False)
                
                
            # real concatenate
            pred_real = self.network_D(real_AB)
            loss_D_real = self.criterionGAN(pred_real, True)
            
            # combine loss and calculate gradients
            loss_D = (loss_D_fake + loss_D_real) / (2.0 * batch_size)
            
            loss_D.backward()
            self.optimizer_D.step() 
            
            
            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #
            self.set_requires_grad(self.network_G, True)
            self.set_requires_grad(self.network_D, False)
            self.optimizer_G.zero_grad()
            
            fake_AB = torch.cat((image, fake_B), 1)
            pred_fake = self.network_D(fake_AB)
            
            loss_G_GAN = self.criterionGAN(pred_fake, True)
            # L1 loss G(A) = B
            loss_G_L1 = self.criterionL1(fake_B, real_B) * 100.0 # 100.0 is lambda_L1 (weight for L1 loss)
            # combine loss and calculate gradients
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            
            self.optimizer_G.step()
            
            # display G loss
            train_loss += loss_G.item()
            #===================================================================================================#
            
            
            
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_G.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

#======================================= no load checkpoint ==================#
#        if self.args.no_val:
#            # save checkpoint every epoch
#            is_best = False
#            self.saver.save_checkpoint({
#                'epoch': epoch + 1,
#                'state_dict': self.model.module.state_dict(),
#                'optimizer': self.optimizer.state_dict(),
#                'best_pred': self.best_pred,
#            }, is_best)
#=============================================================================#

    def validation(self, epoch):
        self.network_G.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.network_G(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        
        
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            
            #============== no save checkpoint ======================#
#            self.saver.save_checkpoint({
#                'epoch': epoch + 1,
#                'state_dict': self.model.module.state_dict(),
#                'optimizer': self.optimizer.state_dict(),
#                'best_pred': self.best_pred,
#            }, is_best)
            #=======================================================#
            
    #========================== new method ===============================# 
    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
Example #10
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                #print(model.state_dict().keys())
                pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()}
                
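                # remap high_level_features blocks 4-6 in the checkpoint to the
                # corresponding low_level_features keys, and keep this model's
                # freshly initialized weights for the reshaped decoder layers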
                del_list = []
                add_list = []
                for key in pretrained_dict.keys():
                    if key.split('.')[1] == 'high_level_features' and key.split('.')[2] in ('4', '5', '6'):
                        #pretrained_dict[key.replace('high','low')] = pretrained_dict[key] 
                        add_list.append(key)
                        del_list.append(key)
                
                for key in add_list:
                    pretrained_dict[key.replace('high','low')] = pretrained_dict[key] 
                for key in del_list:
                    del pretrained_dict[key] 


                pretrained_dict['decoder.conv1.weight'] = model.state_dict()['decoder.conv1.weight']
                # pretrained_dict['decoder.bn1.weight'] = model.state_dict()['decoder.bn1.weight']
                # pretrained_dict['decoder.bn1.bias'] = model.state_dict()['decoder.bn1.bias']
                # pretrained_dict['decoder.bn1.running_mean'] = model.state_dict()['decoder.bn1.running_mean']
                # pretrained_dict['decoder.bn1.running_var'] = model.state_dict()['decoder.bn1.running_var']
                # pretrained_dict['decoder.last_conv.0.weight'] = model.state_dict()['decoder.last_conv.0.weight'] 
                pretrained_dict['decoder.last_conv.8.weight'] = model.state_dict()['decoder.last_conv.8.weight']
                pretrained_dict['decoder.last_conv.8.bias'] = model.state_dict()['decoder.last_conv.8.bias']
                self.model.module.load_state_dict(pretrained_dict)
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            #print('target:',target.shape)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output, img = self.model(image)

            #loss = self.criterion(output, target)
            loss = L.xloss(output, target.long(), ignore=255)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)


        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        num_img_tr = len(self.train_loader)
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        F1 = 0.0
        index = 0
        FF = FT = TF = TT = 0  # pixel confusion counts: FF=TN, FT=FP, TF=FN, TT=TP
        for i, sample in enumerate(tbar):
            image, target = sample[0]['image'], sample[0]['label']
            w = sample[1]
            h = sample[2]
            name = sample[3]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output, img = self.model(image)
            #loss = self.criterion(output, target)
            loss = L.xloss(output, target.long(), ignore=255)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            #pred = output.data.cpu().numpy()
            pred = img.data.cpu().numpy()
            #summary
            global_step = i + num_img_tr * epoch
            self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
            
            # convert the size tensors to plain Python lists
            h = h.numpy().tolist()
            w = w.numpy().tolist()
            index += len(h)
            if target.size()[0] == 1:
                target = target.cpu().numpy().astype(np.uint8)   
            else:
                target = target.cpu().numpy().squeeze().astype(np.uint8)
            pred = np.argmax(pred, axis=1)

            # use a distinct inner index so the outer batch index i is preserved
            for j in range(len(h)):
                target_ = target[j]
                pred_ = pred[j]
                tar_img = Image.fromarray(target_)
                pre_img = Image.fromarray(pred_.squeeze().astype(np.uint8))
                tar_img = Resize((h[j], w[j]), interpolation=2)(tar_img)
                pred_ = Resize((h[j], w[j]), interpolation=2)(pre_img)
                target_ = np.array(tar_img)
                pred_ = np.array(pred_)
                pred_[pred_ != 0] = 1
                target_[target_ != 0] = 1
                pred_ = pred_.astype(int)
                target_ = target_.astype(int)
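                # encode each pixel pair as target*2 + pred -> 0:TN, 1:FP, 2:FN, 3:TP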
                ff, ft, tf, tt = np.bincount((target_*2+pred_).reshape(-1), minlength=4)
                #print(ff,ft,tf,tt)
                FF += ff
                FT += ft
                TF += tf
                TT += tt
            
                # F1 score 
                #F1 += self.evaluator.F1_score(target_, pred_)

            # Add batch sample into evaluator

                self.evaluator.add_batch(target_, pred_)
            
            # image_np = image[0].cpu().numpy()
            # image_np = np.array((image_np*128+128).transpose((1,2,0)),dtype=np.uint8)
            # self.writer.add_image('Input', image_np)

        # with the bincount encoding, FT counts false positives and TF false
        # negatives, so R below is actually precision and P recall; F1 is
        # symmetric in the two, so the final score is unaffected
        R = TT / float(TT + FT)
        P = TT / float(TT + TF)
        F1 = (2 * R * P) / (R + P)
        #Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        desire = (F1 + mIoU)*0.5

        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('val/F1_score', F1, epoch)
        self.writer.add_scalar('val/desire', desire, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, F1_score: {}, desire: {}".format(Acc, Acc_class, mIoU, FWIoU, F1, desire))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Example #11
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone, BatchNorm=BN)
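        # parameter groups: f = feature extractor (backbone + ASPP),
        # y = segmentation decoder, d = domain classifier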
        f_params = list(self.backbone_model.parameters()) + list(
            self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(
                f_params + y_params,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params,
                                               lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay,
                                               nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(
                f_params,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params + y_params,
                                               lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay,
                                               nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params,
                                                   lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params + y_params,
                                                lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets', args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader,
                                                  self.nclass,
                                                  classes_weights_path,
                                                  self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(
                self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(
                self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model,
                                                 device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model,
                                                 device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(
                    checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(
                    checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(
                    checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(
                    checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(
                    checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(
                    checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(
                    checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(
                    checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def validation(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:], \
                                           mode='bilinear', align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        ClassName = [
            "road", "sidewalk", "building", "wall", "fence", "pole", "light",
            "sign", "vegetation", "terrain", "sky", "person", "rider", "car",
            "truck", "bus", "train", "motocycle", "bicycle"
        ]
        with open('val_info.txt', 'a') as f1:
            f1.write('Validation:' + '\n')
            f1.write('[Epoch: %d, numImages: %5d]' %
                     (epoch, i * self.args.batch_size + image.data.shape[0]) +
                     '\n')
            f1.write("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU) + '\n')
            f1.write('Loss: %.3f' % test_loss + '\n' + '\n')
            f1.write('Class IOU: ' + '\n')
            for idx in range(19):
                f1.write('\t' + ClassName[idx] +
                         (': \t' if len(ClassName[idx]) > 5 else ': \t\t') +
                         str(IoU[idx]) + '\n')

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        print(IoU)

        new_pred = mIoU

    def imgsaver(self, img, imgname, miou):
        im1 = np.uint8(img.transpose(1, 2, 0)).squeeze()
        #filename_list = sorted(os.listdir(self.args.test_img_root))

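        # map the 19 training class indices back to the original label IDs
        # (Cityscapes convention, judging by the class list above)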
        valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
            32, 33
        ]
        class_map = dict(zip(range(19), valid_classes))
        im1_np = np.uint8(np.zeros([513, 513]))
        for _validc in range(19):
            im1_np[im1 == _validc] = class_map[_validc]
        saveim1 = Image.fromarray(im1_np, mode='L')
        saveim1 = saveim1.resize((1280, 640), Image.NEAREST)
        # saveim1.save('result_val/'+imgname)

        palette = [[128, 64, 128], [244, 35, 232], [70, 70, 70],
                   [102, 102, 156], [190, 153, 153], [153, 153, 153],
                   [250, 170, 30], [220, 220, 0], [107, 142, 35],
                   [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                   [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                   [0, 0, 230], [119, 11, 32]]
        #[0,0,0]]
        class_color_map = dict(zip(range(19), palette))
        im2_np = np.uint8(np.zeros([513, 513, 3]))
        for _validc in range(19):
            im2_np[im1 == _validc] = class_color_map[_validc]
        saveim2 = Image.fromarray(im2_np)
        saveim2 = saveim2.resize((1280, 640), Image.NEAREST)
        saveim2.save('result_val/' + imgname[:-4] + '_color_' + str(miou) +
                     '_.png')
        # print('saving: '+filename_list[idx])

    def validationSep(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
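        # the evaluator is reset for every sample below, so the mIoU passed
        # to imgsaver reflects each image individually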
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            self.evaluator.reset()
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:], \
                                           mode='bilinear', align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
            self.imgsaver(pred, sample['name'][0], mIoU)
Example #12
class Trainer(object):
    def __init__(self, args):
        self.args = args
        if self.args.sync_bn:
            self.args.batchnorm_function = SynchronizedBatchNorm2d
        else:
            self.args.batchnorm_function = torch.nn.BatchNorm2d
        print(self.args)
        # Define Saver
        self.saver = Saver(self.args)
        # Define Tensorboard Summary
        self.summary = TensorboardSummary()
        self.writer = self.summary.create_summary(self.saver.experiment_dir)

        # Define Dataloader
        kwargs = {'num_workers': self.args.gpus, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            self.args, **kwargs)
        self.nclass = self.args.num_classes

        # Define network
        model = generate_net(self.args)
        train_params = [{
            'params': model.get_conv_weight_params(),
            'lr': self.args.lr,
            'weight_decay': self.args.weight_decay
        }, {
            'params': model.get_conv_bias_params(),
            'lr': self.args.lr * 2,
            'weight_decay': 0
        }]

        # Define Optimizer
        if self.args.optim_method == 'sgd':
            optimizer = torch.optim.SGD(train_params,
                                        momentum=self.args.momentum,
                                        lr=self.args.lr,
                                        weight_decay=0,
                                        nesterov=self.args.nesterov)
        elif self.args.optim_method == 'adagrad':
            optimizer = torch.optim.Adagrad(
                train_params,
                lr=self.args.lr,
                weight_decay=self.args.weight_decay)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if self.args.use_balanced_weights:
            classes_weights_path = os.path.join(self.args.save_dir,
                                                self.args.dataset,
                                                'classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.args.save_dir,
                                                  self.args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32)).type(
                torch.FloatTensor)
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight,
            cuda=self.args.cuda,
            foreloss_weight=args.foreloss_weight,
            seloss_weight=args.seloss_weight).build_loss(
                mode=self.args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        self.evaluator_inner = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.args.lr_scheduler, self.args.lr,
                                      self.args.epochs, len(self.train_loader))

        self.model = self.model.cuda()
        # Resuming checkpoint
        self.args.start_epoch = 0
        self.best_pred = 0.0
        if self.args.resume is not None:
            optimizer, start_epoch, best_pred = load_pretrained_mode(
                self.model, checkpoint_path=self.args.resume)
            if not self.args.ft and optimizer is not None:
                self.optimizer.load_state_dict(optimizer)
                self.args.start_epoch = start_epoch
                self.best_pred = best_pred
        # Using cuda
        if self.args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        self.evaluator.reset()
        self.evaluator_inner.reset()
        print('Training')
        start_time = time.time()
        for i, sample in enumerate(self.train_loader):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            current_lr = self.scheduler(self.optimizer, i, epoch,
                                        self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss, output = self.criterion(output, target)
            pred = output.data.clone()
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            pred = pred.data.cpu().numpy()
            target_array = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator_inner.add_batch(target_array, pred)
            self.evaluator.add_batch(target_array, pred)
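            # every 10 iterations, report metrics over the last window and
            # reset the inner evaluator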
            if i % 10 == 0:
                Acc_train = self.evaluator_inner.Pixel_Accuracy()
                Acc_class_train = self.evaluator_inner.Pixel_Accuracy_Class()
                mIoU_train, IoU_train = self.evaluator_inner.Mean_Intersection_over_Union()
                FWIoU_train = self.evaluator_inner.Frequency_Weighted_Intersection_over_Union()
                print(
                    '\n===>Iteration  %d/%d    learning_rate: %.6f   metric:' %
                    (i, num_img_tr, current_lr))
                print(
                    '=>Train loss: %.4f    acc: %.4f     m_acc: %.4f     miou: %.4f     fwiou: %.4f'
                    % (loss.item(), Acc_train, Acc_class_train, mIoU_train,
                       FWIoU_train))
                print("IoU per class: ", IoU_train)
                self.evaluator_inner.reset()

            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if num_img_tr > 10:
                if i % (num_img_tr // 10) == 0:
                    global_step = i + num_img_tr * epoch
                    self.summary.visualize_image(self.writer,
                                                 self.args.dataset, image,
                                                 target, output, global_step)
            else:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)
        Acc_train_epoch = self.evaluator.Pixel_Accuracy()
        Acc_class_train_epoch = self.evaluator.Pixel_Accuracy_Class()
        mIoU_train_epoch, IoU_train_epoch = self.evaluator.Mean_Intersection_over_Union()
        FWIoU_train_epoch = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        stop_time = time.time()
        self.writer.add_scalar('train/total_loss_epoch',
                               train_loss / num_img_tr, epoch)
        print(
            '=====>[Epoch: %d, numImages: %5d   time_consuming: %d]' %
            (epoch, num_img_tr * self.args.batch_size, stop_time - start_time))
        print(
            "Loss: %.3f  Acc: %.4f,  Acc_class: %.4f,  mIoU: %.4f,  fwIoU: %.4f\n\n"
            % (train_loss / (num_img_tr), Acc_train_epoch,
               Acc_class_train_epoch, mIoU_train_epoch, FWIoU_train_epoch))
        print("IoU per class: ", IoU_train_epoch)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        print('\nValidation')
        num_img_tr = len(self.val_loader)
        start_time = time.time()
        for i, sample in enumerate(self.val_loader):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                loss, output = self.criterion(output, target)
            test_loss += loss.item()
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
        stop_time = time.time()
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss / num_img_tr,
                               epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print(
            '=====>[Epoch: %d, numImages: %5d   previous best=%.4f    time_consuming: %d]'
            % (epoch, num_img_tr * self.args.gpus, self.best_pred,
               (stop_time - start_time)))
        print(
            "Loss: %.3f  Acc: %.4f,  Acc_class: %.4f,  mIoU: %.4f,  fwIoU: %.4f\n\n"
            % (test_loss / (num_img_tr), Acc, Acc_class, mIoU, FWIoU))
        print("IoU per class: ", IoU)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        else:
            is_best = False
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': new_pred,
            }, is_best)
Example #13
class Valuator(object):
    def __init__(self, args):
        self.args = args
        self.args.batchnorm_function = torch.nn.BatchNorm2d
        # Define Dataloader
        self.nclass = self.args.num_classes
        # Define network
        model = generate_net(self.args)

        self.model = model
        self.evaluator = Evaluator(self.nclass)
        self.criterion = SegmentationLosses(cuda=True).build_loss(mode='ce')
        # Using cuda
        if self.args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        _, _, _ = load_pretrained_mode(self.model,
                                       checkpoint_path=self.args.resume)

    def visual(self):
        self.model.eval()
        print('\nvisualizing')
        self.evaluator.reset()
        data_dir = self.args.data_dir
        data_list = os.path.join(data_dir, self.args.val_list)
        vis_set = GenDataset(self.args, data_list, split='vis')
        vis_loader = DataLoader(vis_set, batch_size=1, shuffle=False)
        num_img_tr = len(vis_loader)
        print('=====>[numImages: %5d]' % (num_img_tr))
        for i, sample in enumerate(vis_loader):
            image, target, name, ori = sample['image'], sample[
                'label'], sample['name'], sample['ori']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                if isinstance(output, (tuple, list)):
                    output = output[0]
            output = torch.nn.functional.interpolate(output,
                                                     size=ori.size()[1:3],
                                                     mode='bilinear',
                                                     align_corners=True)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            ori = ori.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            if num_img_tr > 100:
                if i % (num_img_tr // 100) == 0:
                    self.save_img(ori, target, pred, name)
            else:
                self.save_img(ori, target, pred, name)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("IoU per class: ", IoU)

    def save_img(self, images, labels, predictions, names):
        save_dir = self.args.save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        num_image = len(labels)
        labels = decode_seg_map_sequence(labels).cpu().numpy().transpose(
            0, 2, 3, 1)
        predictions = decode_seg_map_sequence(
            predictions).cpu().numpy().transpose(0, 2, 3, 1)
        for i in range(num_image):
            name = names[i]
            if not isinstance(name, str):
                name = str(name)
            save_name = os.path.join(save_dir, name + '.png')
            image = images[i, :, :, :]
            label_mask = labels[i, :, :, :]
            prediction = predictions[i, :, :, :]
            if image.shape != label_mask.shape:
                print('error in %s' % name)
                continue
            label_map = self.addImage(image.astype(dtype=np.uint8),
                                      label_mask.astype(dtype=np.uint8))
            pred_map = self.addImage(image.astype(dtype=np.uint8),
                                     prediction.astype(dtype=np.uint8))
            label = img.fromarray(label_map.astype(dtype=np.uint8), mode='RGB')
            pred = img.fromarray(pred_map.astype(dtype=np.uint8), mode='RGB')
            label_mask = img.fromarray(label_mask.astype(dtype=np.uint8),
                                       mode='RGB')
            pred_mask = img.fromarray(prediction.astype(dtype=np.uint8),
                                      mode='RGB')
            shape1 = label.size
            shape2 = pred.size
            assert (shape1 == shape2)
            width = 2 * shape1[0] + 60
            height = 2 * shape1[1] + 60
            toImage = img.new('RGB', (width, height))
            toImage.paste(pred, (0, 0))
            toImage.paste(label, (shape1[0] + 60, 0))
            toImage.paste(pred_mask, (0, shape1[1] + 60))
            toImage.paste(label_mask, (shape1[0] + 60, shape1[1] + 60))
            toImage.save(save_name)

    def addImage(self, img1, img2):
        # Alpha-blend two ndarray images (e.g. the original image and a colored mask)
        alpha = 1
        beta = 0.7
        gamma = 0
        img_add = cv2.addWeighted(img1, alpha, img2, beta, gamma)
        return img_add
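Every example in this collection funnels predictions through Evaluator.add_batch and then reads confusion-matrix metrics from it. A minimal sketch of such an Evaluator, modeled on the widely used pytorch-deeplab-xception metrics module; the internals are an assumption, and as the differing calls above show, some variants return only mIoU while others also return the per-class IoU vector:

import numpy as np

class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class))

    def _generate_matrix(self, gt_image, pre_image):
        # rows index ground-truth classes, columns index predicted classes
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Pixel_Accuracy_Class(self):
        acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        return np.nanmean(acc)

    def Mean_Intersection_over_Union(self):
        # per-class IoU = TP / (TP + FP + FN); mIoU averages over classes
        IoU = np.diag(self.confusion_matrix) / (
            self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0)
            - np.diag(self.confusion_matrix))
        return np.nanmean(IoU), IoU

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = self.confusion_matrix.sum(axis=1) / self.confusion_matrix.sum()
        iu = np.diag(self.confusion_matrix) / (
            self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0)
            - np.diag(self.confusion_matrix))
        return (freq[freq > 0] * iu[freq > 0]).sum()

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))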
Example #14
class Tester(object):
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        val_set = pascal.VOCSegmentation(args, split='val')
        self.nclass = val_set.NUM_CLASSES
        self.val_loader = DataLoader(val_set,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     **kwargs)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            print('device_ids', self.args.gpu_ids)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

    def visualization(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        num_img_val = len(self.val_loader)
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            # Save images, predictions, and targets to disk
            if i % max(num_img_val // 10, 1) == 0:  # max() guards against fewer than 10 batches
                self.save_batch_images(image, output, target, i)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            # if i == 0:
            #     break

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        # mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        mIoU = self.evaluator.All_Mean_Intersection_over_Union()

        print('Validation:')
        print("Acc:{}, Acc_class:{}, fwIoU: {}".format(Acc, Acc_class, FWIoU))
        print("mIoU:{:.4f} {:.4f} {:.4f} {:.4f}".format(
            mIoU[0], mIoU[1], mIoU[2], mIoU[3]))
        print('Loss: %.3f' % test_loss)

    def save_batch_images(self, imgs, preds, targets, batch_index):
        (filepath, _) = os.path.split(self.args.resume)
        save_path = os.path.join(filepath, 'visualization')
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        grid_image = make_grid(imgs.clone().detach().cpu(), 8, normalize=True)
        save_image(
            grid_image,
            os.path.join(save_path,
                         'batch_{:0>4}-img.jpg'.format(batch_index)))
        grid_image = make_grid(decode_seg_map_sequence(
            torch.max(preds, 1)[1].detach().cpu().numpy(),
            dataset=self.args.dataset),
                               8,
                               normalize=False,
                               range=(0, 255))
        save_image(
            grid_image,
            os.path.join(save_path,
                         'batch_{:0>4}-pred.png'.format(batch_index)))
        grid_image = make_grid(decode_seg_map_sequence(
            torch.squeeze(targets, 1).detach().cpu().numpy(),
            dataset=self.args.dataset),
                               8,
                               normalize=False,
                               range=(0, 255))
        save_image(
            grid_image,
            os.path.join(save_path,
                         'batch_{:0>4}-target.png'.format(batch_index)))
Example #15
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = MyDeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        #optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                            weight_decay=args.weight_decay, nesterov=args.nesterov)

        # adam
        optimizer = torch.optim.Adam(params=self.model.parameters(), betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=0, amsgrad=False)

        weight = [1, 10, 10, 10, 10, 10, 10, 10]
        weight = torch.tensor(weight, dtype=torch.float) 
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda, num_classes=self.nclass).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
        
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        '''
        # Get the name of each layer in the current model
        layer_name = list(self.model.state_dict().keys())
        #print(self.model.state_dict()[layer_name[3]])
        # Load a generic pretrained model
        pretrained = './pretrained_model/deeplab-mobilenet.pth.tar'
        pre_ckpt = torch.load(pretrained)
        key_name = list(checkpoint['state_dict'].keys()) # get the name of each layer in the pretrained model
        pre_ckpt['state_dict'][key_name[-2]] = checkpoint['state_dict'][key_name[-2]] # the class count differs, so the last two layers are assigned separately
        pre_ckpt['state_dict'][key_name[-1]] = checkpoint['state_dict'][key_name[-1]]
        self.model.module.load_state_dict(pre_ckpt['state_dict'])     # , strict=False)
        #print(self.model.state_dict()[layer_name[3]])
        print("Pretrained model loaded OK")
        '''
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            #import pdb
            #pdb.set_trace()
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            #if (i+1) % 50 == 0:
            #    print('Train loss: %.3f' % (loss.item() / (i + 1)))
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # max() guards against fewer than 10 batches
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        filename = 'checkpoint_{}_{:.4f}.pth.tar'.format(epoch, train_loss)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, filename=filename)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            #if (i+1) %20 == 0:
            #    print('Test loss: %.3f' % (loss / (i + 1)))
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
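The hand-set per-class weights above ([1, 10, ...]) can instead be derived from label frequencies. A hedged sketch using inverse log-frequency (ENet-style) weighting, similar in spirit to the calculate_weigths_labels helper used in the next example, though the exact formula there may differ:

import numpy as np
import torch

def class_weights_from_loader(loader, num_classes):
    # Count pixels per class over the whole dataset
    counts = np.zeros(num_classes)
    for sample in loader:
        labels = sample['label'].numpy().astype(int)
        mask = (labels >= 0) & (labels < num_classes)  # skip ignore labels such as 255
        counts += np.bincount(labels[mask].ravel(), minlength=num_classes)
    freq = counts / counts.sum()
    weights = 1.0 / np.log(1.02 + freq)  # rare classes get larger weights
    return torch.from_numpy(weights.astype(np.float32))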
Example #16
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        self.t_net = DeepLab(num_classes=self.nclass,
                             backbone='resnet101',
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        checkpoint = torch.load('pretrained/deeplab-resnet.pth.tar')
        self.t_net.load_state_dict(checkpoint['state_dict'])

        self.s_net = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.d_net = distiller.Distiller(self.t_net, self.s_net)

        print('Teacher Net: ')
        print(self.t_net)
        print('Student Net: ')
        print(self.s_net)
        print('the number of teacher model parameters: {}'.format(
            sum([p.data.nelement() for p in self.t_net.parameters()])))
        print('the number of student model parameters: {}'.format(
            sum([p.data.nelement() for p in self.s_net.parameters()])))

        self.distill_ratio = 1e-5
        self.batch_size = args.batch_size

        distill_params = [{
            'params': self.s_net.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': self.s_net.get_10x_lr_params(),
            'lr': args.lr * 10
        }, {
            'params': self.d_net.Connectors.parameters(),
            'lr': args.lr * 10
        }]

        init_params = [{
            'params': self.d_net.Connectors.parameters(),
            'lr': args.lr * 10
        }]

        # # Define Optimizer
        self.optimizer = torch.optim.SGD(distill_params,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)
        self.init_optimizer = torch.optim.SGD(init_params,
                                              momentum=args.momentum,
                                              weight_decay=args.weight_decay,
                                              nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.s_net = torch.nn.DataParallel(self.s_net).cuda()
            self.d_net = torch.nn.DataParallel(self.d_net).cuda()

        # Resuming checkpoint
        self.best_pred = 0.0

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.d_net.train()
        self.d_net.module.t_net.train()
        self.d_net.module.s_net.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        if epoch == 0:
            optimizer = self.init_optimizer
        else:
            optimizer = self.optimizer

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            batch_size = image.shape[0]
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(optimizer, i, epoch, self.best_pred)
            optimizer.zero_grad()
            output, loss_distill = self.d_net(image)

            loss_seg = self.criterion(output, target)
            loss = loss_seg + loss_distill.sum() / batch_size * self.distill_ratio

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.s_net.module.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.s_net.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.s_net(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.s_net.module.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Example #17
for i, sample in enumerate(val_loader):
    batch, target = sample['image'], sample['label']
    if args.cuda:
        batch = batch.cuda()
    with torch.no_grad():  # Variable is deprecated; run inference without autograd
        output = model(batch)
    pred = output.data.cpu().numpy()
    target = target.cpu().numpy()
    pred = np.argmax(pred, axis=1)
    # Add batch sample into evaluator
    evaluator.add_batch(target, pred)

Acc = evaluator.Pixel_Accuracy() * 100
Acc_class = evaluator.Pixel_Accuracy_Class() * 100
mIoU = evaluator.Mean_Intersection_over_Union() * 100
FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() * 100
print("Acc:{0:.3f}, Acc_class:{1:.3f}, mIoU:{2:.3f}, fwIoU: {3:.3f}".format(
    Acc, Acc_class, mIoU, FWIoU))

# model_acc[curr_iter] = Acc
# model_acc_class[curr_iter] = Acc_class
# model_miou[curr_iter] = mIoU

title = 'Inference Time of Pruned Model'
plt.plot(total_infer_times, '-b', label='infer_times')
plt.plot(model_acc, '-r', label='accuracy')
plt.plot(model_acc_class, '-g', label='accuracy_class')
plt.plot(model_miou, '-k', label='miou')
plt.legend(loc='upper left')
plt.xlabel("pruning iterations")
plt.title(title)
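Note that total_infer_times, model_acc, model_acc_class and model_miou are not defined in this excerpt (the assignments are commented out above). A plausible accumulation pattern across pruning iterations, with evaluate_pruned_model as a hypothetical helper that re-runs the evaluation loop shown above and returns the metrics for one iteration:

def track_pruning_metrics(evaluate_pruned_model, num_iterations):
    total_infer_times, model_acc, model_acc_class, model_miou = [], [], [], []
    for curr_iter in range(num_iterations):
        # one pruning step followed by evaluation of the pruned model
        infer_time, acc, acc_class, miou = evaluate_pruned_model(curr_iter)
        total_infer_times.append(infer_time)
        model_acc.append(acc)
        model_acc_class.append(acc_class)
        model_miou.append(miou)
    return total_infer_times, model_acc, model_acc_class, model_miou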
Example #18
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                #if so, which trainloader to use?
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(num_classes=self.nclass,
                            num_layers=12,
                            criterion=self.criterion,
                            filter_multiplier=self.args.filter_multiplier)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)

        self.architect = Architect(self.model, args)

        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean it
            if args.clean_module:
                # strip the 'module.' prefix that DataParallel adds to every key
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            if epoch > self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(), target_search.cuda()
                self.architect.step(image_search, target_search)

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            #self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # max() guards against fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

            #torch.cuda.empty_cache()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
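The Architect used above performs the DARTS-style bi-level update: architecture parameters are optimized on the held-out search split (train_loaderB) while the network weights train on train_loaderA. A minimal sketch of the simple first-order variant; the internals and the model.arch_parameters() / model._loss() accessors are assumptions (the real AutoDeeplab Architect also supports an unrolled second-order update):

import torch

class Architect(object):
    def __init__(self, model, args):
        self.model = model
        # separate optimizer over the architecture (alpha/beta) parameters
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_lr, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

    def step(self, input_valid, target_valid):
        # one first-order gradient step on the search split
        self.optimizer.zero_grad()
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()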
Example #19
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.vs = Vs(args.dataset)

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        if self.args.norm == "gn":
            norm = gn
        elif self.args.norm == "bn":
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == "abn":
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            print("Please check the norm.")
            exit()

        # Define network
        model = LaneDeepLab(args=args, num_classes=self.nclass)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset), args.dataset + "_classes_weights.npy"
            )
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass
                )
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(
            mode=args.loss_type
        )
        self.model = model

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            self.best_pred = checkpoint["best_pred"]
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def test(self):
        self.model.eval()
        self.args.examine = False
        tbar = tqdm(self.test_loader, desc="\r")
        __image = bool(self.args.color)
        for i, sample in enumerate(tbar):
            images = sample["image"]
            names = sample["name"]
            if self.args.cuda:
                images = images.cuda()
            with torch.no_grad():
                output = self.model(images)
            preds = output.data.cpu().numpy()
            preds = np.argmax(preds, axis=1)
            if __image:
                images = images.cpu().numpy()
            if not self.args.color:
                self.vs.predict_id(preds, names, self.args.save_dir)
            else:
                self.vs.predict_color(preds, images, names, self.args.save_dir)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        __image = bool(self.args.color or self.args.examine)
        for i, sample in enumerate(tbar):
            images, targets = sample["image"], sample["label"]
            names = sample["name"]
            lanes = sample["lanes"]
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()
                lanes = lanes.cuda().unsqueeze(1)
            with torch.no_grad():
                output = self.model(images)
            # debug:
            # print("output.shape: ", output.shape)
            # print("targets.shape: ", targets.shape)
            loss = self.criterion(output, (targets, lanes))
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            preds = output.data.cpu().numpy()
            targets = targets.cpu().numpy()
            preds = np.argmax(
                preds[:, 0:-1, :, :], axis=1
            )  # since the last channel is lane lines
            # Add batch sample into evaluator
            # debug:
            # print("targets.shape: ", targets.shape)
            # print("preds.shape: ", preds.shape)
            self.evaluator.add_batch(targets, preds)  # original
            if __image:
                images = images.cpu().numpy()
            if self.args.id:
                self.vs.predict_id(preds, names, self.args.save_dir)
            if self.args.color:
                self.vs.predict_color(preds, images, names, self.args.save_dir)
            if self.args.examine:
                self.vs.predict_examine(
                    preds, targets, images, names, self.args.save_dir
                )

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Validation:")
        # print(
        #     "[Epoch: %d, numImages: %5d]"
        #     % (epoch, i * self.args.batch_size + image.data.shape[0])
        # )
        print(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU
            )
        )
        print("Loss: %.3f" % test_loss)
Example #20
class Trainer(object):
    def __init__(self, args):
        self.args = args

        self.saver = Saver(args)
        self.saver.save_experiment_config()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # self.model = OCRNet(self.nclass)
        self.model = build_model(2, [32, 32], '44330020')
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)
        if args.use_balanced_weights:
            weight = torch.tensor([0.2, 0.8], dtype=torch.float32)
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        if args.cuda:
            self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the model is never wrapped in DataParallel in this Trainer, so
            # load the state dict directly in both the cuda and cpu cases
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

        print('[Epoch:{},num_images:{}]'.format(
            epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss:{}'.format(train_loss))

        if self.args.no_val:
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(target, pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        road_iou, mIOU = self.evaluator.Mean_Intersection_over_Union()
        FWIOU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:\n')
        print('[Epoch:{},num_image:{}]'.format(
            epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Acc:{},Acc_class:{},mIOU:{},road_iou:{},fwIOU:{}'.format(
            Acc, Acc_class, mIOU, road_iou, FWIOU))
        print('Loss:{}'.format(test_loss))

        new_pred = road_iou
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Example #21
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define network
        model = DeepLab(num_classes=args.num_classes,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, self.valid_loader = make_data_loader(args, **kwargs)
        self.pred_remap = args.pred_remap
        self.gt_remap = args.gt_remap

        # Define Evaluator
        self.evaluator = Evaluator(args.eval_num_classes)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

    def validation(self):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.valid_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)

            if self.gt_remap is not None:
                target = self.gt_remap[target.astype(int)]
            if self.pred_remap is not None:
                pred = self.pred_remap[pred.astype(int)]
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
Example #22
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # unzip data file
        if not os.path.exists("../../data/jingwei_round1_train_20190619"):
            print("=> unzip data files...")
            os.system(
                'unzip ../../data/jingwei_round1_train_20190619.zip -d ../../data'
            )
            os.system(
                'unzip ../../data/jingwei_round1_test_a_20190619.zip -d ../../data'
            )

        # Generate .npy file for dataloader
        self.img_process(args)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = getattr(modeling, args.model_name)(pretrained=args.pretrained)

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    # Crop the large images into unit_size tiles, sliding the window by stride
    # each step. The resulting train/val splits are stored in save_dir as .npy
    # arrays for reuse and to reduce memory usage. Change the paths to your own.
    def img_process(self, args):
        print("=> starting model_2 img_process...")
        unit_size = args.base_size
        stride = unit_size  # int(unit_size/2)
        if not os.path.exists('../../data/train_model_2'):
            os.mkdir('../../data/train_model_2')
        save_dir = os.path.join('../../data/train_model_2', str(unit_size))
        # npy_process
        if not os.path.exists(save_dir):

            Image.MAX_IMAGE_PIXELS = 100000000000
            # load train image 1
            img = Image.open(
                '../../data/jingwei_round1_train_20190619/image_1.png')
            img = np.asarray(img)  #(50141, 47161, 4)
            anno_map = Image.open(
                '../../data/jingwei_round1_train_20190619/image_1_label.png')
            anno_map = np.asarray(anno_map)  #(50141, 47161)

            length, width = img.shape[0], img.shape[1]
            x1, x2, y1, y2 = 0, unit_size, 0, unit_size
            Img1 = []  # list of cropped image tiles
            Label1 = []  # list of matching label tiles
            while (x1 < length):
                # clamp the window if it would run past the image boundary
                if x2 > length:
                    x2, x1 = length, length - unit_size

                while (y1 < width):
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    im = img[x1:x2, y1:y2, :]
                    if 255 in im[:, :, -1]:  # keep tiles that contain valid (opaque) pixels
                        Img1.append(im[:, :, 0:3])  # append the RGB tile
                        Label1.append(anno_map[x1:x2, y1:y2])  # append its label tile

                    if y2 == width: break

                    y1 += stride
                    y2 += stride

                if x2 == length: break

                y1, y2 = 0, unit_size
                x1 += stride
                x2 += stride
            Img1 = np.array(Img1)  #(4123, 448, 448, 3)
            Label1 = np.array(Label1)  #(4123, 448, 448)

            # load train image 2
            img = Image.open(
                '../../data/jingwei_round1_train_20190619/image_2.png')
            img = np.asarray(img)  #(50141, 47161, 4)
            anno_map = Image.open(
                '../../data/jingwei_round1_train_20190619/image_2_label.png')
            anno_map = np.asarray(anno_map)  #(50141, 47161)

            length, width = img.shape[0], img.shape[1]
            x1, x2, y1, y2 = 0, unit_size, 0, unit_size
            Img2 = []  # list of cropped image tiles
            Label2 = []  # list of matching label tiles
            while (x1 < length):
                # clamp the window if it would run past the image boundary
                if x2 > length:
                    x2, x1 = length, length - unit_size

                while (y1 < width):
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    im = img[x1:x2, y1:y2, :]
                    if 255 in im[:, :, -1]:  # keep tiles that contain valid (opaque) pixels
                        Img2.append(im[:, :, 0:3])  # append the RGB tile
                        Label2.append(anno_map[x1:x2, y1:y2])  # append its label tile

                    if y2 == width: break

                    y1 += stride
                    y2 += stride

                if x2 == length: break

                y1, y2 = 0, unit_size
                x1 += stride
                x2 += stride
            Img2 = np.array(Img2)  #(5072, 448, 448, 3)
            Label2 = np.array(Label2)  #(5072, 448, 448)

            Img = np.concatenate((Img1, Img2), axis=0)
            cat = np.concatenate((Label1, Label2), axis=0)

            # shuffle
            np.random.seed(1)
            assert (Img.shape[0] == cat.shape[0])
            shuffle_id = np.arange(Img.shape[0])
            np.random.shuffle(shuffle_id)
            Img = Img[shuffle_id]
            cat = cat[shuffle_id]

            os.mkdir(save_dir)
            print("=> generate {}".format(unit_size))
            # split train dataset
            images_train = Img[:int(Img.shape[0] * 0.8)]
            categories_train = cat[:int(cat.shape[0] * 0.8)]
            assert (len(images_train) == len(categories_train))
            np.save(os.path.join(save_dir, 'train_img.npy'), images_train)
            np.save(os.path.join(save_dir, 'train_label.npy'),
                    categories_train)
            # split val dataset
            images_val = Img[int(Img.shape[0] * 0.8):]
            categories_val = cat[int(cat.shape[0] * 0.8):]
            assert (len(images_val) == len(categories_val))
            np.save(os.path.join(save_dir, 'val_img.npy'), images_val)
            np.save(os.path.join(save_dir, 'val_label.npy'), categories_val)

            print("=> img_process finished!")
        else:
            print("{} file already exists!".format(unit_size))
        # free memory: deleting entries from locals() has no effect inside a
        # function (and mutating the dict while iterating raises RuntimeError),
        # so let the references fall out of scope and collect explicitly
        import gc
        gc.collect()
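
    # The two sliding-window loops above are identical apart from the input
    # paths. A hedged refactoring sketch (the helper below is hypothetical and
    # not part of the original code; it assumes the image is at least
    # unit_size along each axis):
    @staticmethod
    def _crop_patches_sketch(img, anno_map, unit_size, stride):
        length, width = img.shape[0], img.shape[1]
        # window origins, with the final window clamped to the image border
        xs = list(range(0, length - unit_size, stride)) + [length - unit_size]
        ys = list(range(0, width - unit_size, stride)) + [width - unit_size]
        imgs, labels = [], []
        for x1 in xs:
            for y1 in ys:
                im = img[x1:x1 + unit_size, y1:y1 + unit_size, :]
                if 255 in im[:, :, -1]:  # alpha channel marks valid pixels
                    imgs.append(im[:, :, 0:3])
                    labels.append(anno_map[x1:x1 + unit_size, y1:y1 + unit_size])
        return np.array(imgs), np.array(labels)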

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # guard against epochs with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
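
The LR_Scheduler driving self.scheduler(self.optimizer, i, epoch, self.best_pred)
above is not included in this listing. A minimal sketch of the "poly" decay
policy such schedulers typically implement (PolyLRSketch is a hypothetical
stand-in, not the repository's actual class; real implementations also scale
the backbone and head parameter groups differently):

class PolyLRSketch(object):
    def __init__(self, base_lr, num_epochs, iters_per_epoch, power=0.9):
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.max_iters = num_epochs * iters_per_epoch
        self.power = power

    def __call__(self, optimizer, i, epoch, best_pred=None):
        # polynomial decay from base_lr down to 0 over the whole run
        cur_iter = epoch * self.iters_per_epoch + i
        lr = self.base_lr * (1 - cur_iter / self.max_iters) ** self.power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr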
Example #23
0
class Infer(object):
    def __init__(self,args):
        self.args = args
        self.nclass  = 4
        self.save_fold = 'brain_re/brain_cedice'
        mkdir(self.save_fold)
        self.name = self.save_fold.split('/')[-1].split('_')[-1]
        #===for brain==========================
        # self.nclass = 4
        # self.save_fold = 'brain_re'
        #======================================
        net = segModel(self.args,self.nclass)
        net.build_model()
        model = net.model
        #load params
        resume = args.resume
        self.model = torch.nn.DataParallel(model)
        self.model = self.model.cuda()
        print('==> Load model...')
        if resume is not None:
            checkpoint = torch.load(resume)
            # model.load_state_dict(checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
        # note: this reassignment discards the DataParallel/.cuda() wrapper built above
        self.model = model
        print('==> Loading loss func...')
        self.criterion = SegmentationLosses(cuda=args.cuda).build_loss(mode=args.loss_type)

        #define evaluator
        self.evaluator = Evaluator(self.nclass)

        #get data path
        root_path = Path.db_root_dir(self.args.dataset)
        if self.args.dataset == 'drive':
            folder = 'test'
            self.test_img = os.path.join(root_path, folder, 'images')
            self.test_label = os.path.join(root_path, folder, '1st_manual')
            self.test_mask = os.path.join(root_path, folder, 'mask')
        elif self.args.dataset == 'brain':
            path = root_path+'/Bra-pickle'
            valid_path = '../data/Brain/test.csv'
            self.valid_set = get_dataset(path,valid_path)
        print('loading test data...')

        #define data
        self.test_loader = None

    def eval(self):
        gt_name = os.listdir(self.test_label)
        img_list = [os.path.join(self.test_label, image) for image in gt_name]
        mask_listdir = [os.path.join(self.test_mask,image.split('.')[0].split('_')[0]+'_test_mask.gif') for image in gt_name]
        pred_list = get_result_list(gt_name,self.save_fold)
        # note: img_list / pred_list / mask_listdir hold file paths here, so
        # add_batch must load the images internally (or a loading/transform
        # step is missing from this example)
        for i in range(len(img_list)):
            target, preds, _mask = img_list[i], pred_list[i], mask_listdir[i]
            self.evaluator.add_batch(target, preds, mask=_mask)
        # idx = len(img_list)
        idx = 1
        test_Acc = self.evaluator.Pixel_Accuracy()
        test_acc_class = self.evaluator.Pixel_Accuracy_Class()
        test_mIou = self.evaluator.Mean_Intersection_over_Union()
        test_fwiou = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        pre,recall,auc=self.evaluator.show_Roc()
        print('Test:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, precision:{}, Recall:{}, Auc:{}".format(test_Acc / idx, test_acc_class / idx, test_mIou / idx,
                                                                test_fwiou / idx,pre,recall,auc))


    def predict_a_patch(self):
        self.model.eval()
        imgs = os.listdir(self.test_img)
        labels = []
        for i in imgs:
            label_name = (i.split('.')[0]).split('_')[0]+'_manual1.gif'
            labels.append(label_name)
        img_list = [os.path.join(self.test_img,image) for image in imgs]
        label_list = [os.path.join(self.test_label,lab) for lab in labels]

        #some params
        patch_h = self.args.ph
        patch_w = self.args.pw
        stride_h = self.args.sh
        stride_w = self.args.sw
        #crop imgs to patches
        images_patch, labels_patch, Height, Width,self.gray_original = extract_patches_test(img_list, label_list, patch_h, patch_w, stride_h,
                                                           stride_w)  # list[patches]
        data = []
        for i, j in zip(images_patch, labels_patch):
            data.append((i, j))

        # start testing; each batch holds one image
        tbar = tqdm(data)
        for idx,sample in enumerate(tbar):
            image,target = sample[0],sample[1]
            #print(image.shape,target.shape)
            image,target = image.cuda(),target.cuda()
            with torch.no_grad():
                result = self._predict_a_patch(image)
            preds = result
            full_preds = merge_overlap(preds, Height, Width, stride_h, stride_w)  # Tensor->[1,1,H,W]
            full_preds = full_preds[0,1,:,:]
            full_img = tfs.ToPILImage()((full_preds*255).type(torch.uint8))
            full_image = (full_preds >= 0.5) * 1  # binarize at threshold 0.5
            mergeImage = merge(self.gray_original[idx],full_image)
            #save result image
            name_probs = imgs[idx].split('.')[0].split('_')[0]+'_test_prob.bmp'
            name_merge = imgs[idx].split('.')[0].split('_')[0]+'_merge.bmp'
            save(mergeImage,os.path.join(self.save_fold,name_merge))
            save(full_img,os.path.join(self.save_fold,name_probs))

    def _predict_a_patch(self, patchs):
        number_of_patch = patchs.shape[0]
        results = torch.zeros(number_of_patch, self.nclass,
                              self.args.ph, self.args.pw)
        results = results.cuda()
        patchs = patchs.float()

        steps = int(number_of_patch / self.args.batch_size)
        end_index = 0  # guard for the case steps == 0
        for i in range(steps):
            start_index = i * self.args.batch_size
            end_index = start_index + self.args.batch_size
            output = self.model(patchs[start_index:end_index])
            output = torch.sigmoid(output)
            results[start_index:end_index] = output
        # only run the remainder batch when one exists (the original indexed
        # patchs[end_index:] unconditionally, which raises a NameError when
        # steps == 0 and feeds an empty tensor when the batch size divides evenly)
        if end_index < number_of_patch:
            results[end_index:] = torch.sigmoid(self.model(patchs[end_index:]))
        return results
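
    # `merge_overlap` used in predict_a_patch above is not defined in this
    # listing. A hedged sketch of the usual overlap-tile reassembly
    # (hypothetical helper; assumes the patches were extracted row-major over
    # the image grid with strides stride_h / stride_w):
    @staticmethod
    def _merge_overlap_sketch(patches, height, width, stride_h, stride_w):
        n, c, ph, pw = patches.shape
        probs = torch.zeros(c, height, width, device=patches.device)
        counts = torch.zeros(1, height, width, device=patches.device)
        idx = 0
        for top in range(0, height - ph + 1, stride_h):
            for left in range(0, width - pw + 1, stride_w):
                probs[:, top:top + ph, left:left + pw] += patches[idx]
                counts[:, top:top + ph, left:left + pw] += 1
                idx += 1
        return (probs / counts).unsqueeze(0)  # [1, C, H, W], overlaps averaged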

    def test(self):
        self.model.eval()
        print(self.model)
        self.evaluator.reset()
        self.test_loader = DataLoader(self.valid_set,batch_size=self.args.test_batch_size,shuffle=False)
        tbar = tqdm(self.test_loader, desc='\r')  # TODO: needs a rewrite
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            #show(image[0].permute(1,2,0).numpy(),target[0].numpy())
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                _,pred = output.max(1)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            target = target.cpu().numpy()
            pred = pred.cpu().numpy()

            # visualize the first 100 samples
            if 0 <= i <= 100:
                iii = image[0].cpu().numpy()
                showimg = np.transpose(iii, (1, 2, 0))
                plt.figure()
                plt.imshow(showimg, cmap='gray')
                plt.show()
                fname = self.save_fold + '/' + self.name + '_' + str(i) + '.png'
                show(image[0].permute(1, 2, 0).cpu().numpy(), target[0], pred[0], fname)
            # if i>99:
            #     break
            #save(pred[0],fname=str(i)+'.jpg')
            #

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        Dice_coff = self.evaluator.DiceCoff()
        P, R, perclass = self.evaluator.compute_el()
        print('Test:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, dice:{}".format(Acc, Acc_class, mIoU, FWIoU,Dice_coff))
        print('precision: {}, recall: {}'.format(P, R))
        print('per-class precision: {}, recall: {}'.format(perclass[0], perclass[1]))
        print('Loss: %.3f' % test_loss)
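
Evaluator.DiceCoff() above is not shown in this listing. For a
confusion-matrix-based evaluator (which the other metrics here suggest), the
per-class Dice coefficient is 2*TP / (2*TP + FP + FN); a minimal NumPy sketch
(dice_from_confusion is a hypothetical name):

import numpy as np

def dice_from_confusion(cm):
    # cm: [K, K] confusion matrix, rows = ground truth, cols = prediction
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = 2 * tp + fp + fn
    dice = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), np.nan)
    return np.nanmean(dice)  # mean Dice over classes that actually occur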
Example #24
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = True if (APEX_AVAILABLE and args.use_amp) else False
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        # self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        if self.args.dataset == '2d':
            self.data_dict, self.nclass = make_data_loader(args, **kwargs)
        elif self.args.dataset == '3d':
            self.data_dict, self.nclass = make_data_loader_3d_patch(
                args, **kwargs)
        print('#' * 35, 'Load data done!')
        # if args.use_balanced_weights:
        # classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
        # if os.path.isfile(classes_weights_path):
        #     weight = np.load(classes_weights_path)
        # else:
        #     raise NotImplementedError
        #     #if so, which trainloader to use?
        #     # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        # weight = torch.from_numpy(weight.astype(np.float32))
        # else:
        weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 8, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(  # optimizer for the trainA split (network weights)
            model.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            # optimizer for the trainB split; it only updates the architecture
            # parameters alpha and beta
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                               args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      self.data_dict['num_train'],
                                      min_lr=args.min_lr)
        # TODO: figure out whether len(self.train_loader) should be divided by two (in other modules as well)
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2'
                                           or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module,
                                  torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a DataParallel 'module.' prefix we have to clean it
            if args.clean_module:
                # note: this direct load looks redundant (the cleaned dict is
                # copied below) and would fail on 'module.'-prefixed keys
                self.model.load_state_dict(checkpoint['state_dict'])
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(),
                                    checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(),
                                    checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
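
    # `copy_state_dict` used throughout the resume logic above is imported
    # from elsewhere in the repository. A plausible minimal sketch of what it
    # does (hypothetical implementation, shown only to make the resume path
    # self-explanatory): copy tensors in place for keys whose shapes match.
    @staticmethod
    def _copy_state_dict_sketch(target_state, source_state):
        with torch.no_grad():
            for k, v in source_state.items():
                if k in target_state and target_state[k].shape == v.shape:
                    target_state[k].copy_(v)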

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        # tbar = tqdm(self.train_loaderA)
        # num_img_tr = len(self.train_loaderA)
        num_img_tr = self.data_dict['num_train']
        num_img_valid = self.data_dict['num_valid']
        for i in range(num_img_tr):
            image, target = torch.FloatTensor(
                self.data_dict['train_data'][i]), torch.FloatTensor(
                    self.data_dict['train_mask'][i])
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.optimizer.step()

            if epoch % self.args.alpha_epoch == 0:
                # search = next(iter(self.train_loaderB))
                # i.e. read the B-split data to optimize alpha and beta
                for j in range(num_img_valid):
                    image_search, target_search = torch.FloatTensor(
                        self.data_dict['valid_data'][j]), torch.FloatTensor(
                            self.data_dict['valid_mask'][j])
                    if self.args.cuda:
                        image_search, target_search = image_search.cuda(
                        ), target_search.cuda()

                    self.architect_optimizer.zero_grad()
                    output_search = self.model(image_search)
                    arch_loss = self.criterion(output_search, target_search)
                    if self.use_amp:
                        with amp.scale_loss(
                                arch_loss,
                                self.architect_optimizer) as arch_scaled_loss:
                            arch_scaled_loss.backward()
                    else:
                        arch_loss.backward()
                    self.architect_optimizer.step()

            train_loss += loss.item()
            print('Train loss: %.3f' % (train_loss / (i + 1)))
            #self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 1) == 0:
            #     global_step = i + num_img_tr * epoch
            #     self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            #torch.cuda.empty_cache()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_img_tr))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            # if torch.cuda.device_count() > 1:
            #     state_dict = self.model.module.state_dict()
            # else:
            state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
            print('save checkpoint')

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        # tbar = tqdm(self.val_loader, desc='\r')

        test_loss = 0.0
        num_img_test = self.data_dict['num_test']

        for i in range(num_img_test):
            image, target = torch.FloatTensor(
                self.data_dict['test_data'][i]), torch.FloatTensor(
                    self.data_dict['test_mask'][i])
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            print('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_img_test))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Example #25
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        # when an initializer checkpoint is given, build a 21-class head to
        # match it (PASCAL-style); the last layer is replaced after loading
        if args.init is not None:
            model = DeepLab(num_classes=21,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)
        else:
            model = DeepLab(num_classes=self.nclass,
                            backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

        # train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr
        }]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)

        optimizer = torch.optim.Adam(train_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     amsgrad=True)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                                    args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        #initializing network
        if args.init is not None:
            if not os.path.isfile(args.init):
                raise RuntimeError(
                    "=> no initializer checkpoint found at '{}'".format(
                        args.init))
            checkpoint = torch.load(args.init)
            #args.start_epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']
            # del state_dict["decoder.last_conv.8.weight"]
            # del state_dict["decoder.last_conv.8.bias"]
            if args.cuda:
                self.model.module.load_state_dict(state_dict, strict=False)
            else:
                self.model.load_state_dict(state_dict, strict=False)
            # if not args.ft:
            #     self.optimizer.load_state_dict(checkpoint['optimizer'])
            # self.best_pred = checkpoint['best_pred']
            self.model.module.decoder.last_layer = nn.Conv2d(256,
                                                             self.nclass,
                                                             kernel_size=1,
                                                             stride=1).cuda()
            print("=> loaded initializer '{}' (epoch {})".format(
                args.init, checkpoint['epoch']))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            ##state_dict = checkpoint['state_dict']
            ## del state_dict["decoder.last_conv.8.weight"]
            ## del state_dict["decoder.last_conv.8.bias"]
            if args.cuda:
                #self.model.module.load_state_dict(state_dict, strict=False)
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                #self.model.load_state_dict(state_dict, strict=False)
                self.model.load_state_dict(checkpoint['state_dict'])
            # if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

            #self.model.module.decoder.last_layer = nn.Conv2d(256, self.nclass, kernel_size=1, stride=1)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_loss1 = 0.0
        train_loss2 = 0.0
        train_loss3 = 0.0
        self.model.train()

        # trying to save a checkpoint and check if it exists...
        # import os
        # cur_path = os.path.dirname(os.path.abspath('.'))
        # print('saving mycheckpoint in:' + cur_path )
        # checkpoint_name = 'mycheckpoint.pth.tar'
        # save_path = cur_path + '/' + checkpoint_name
        # torch.save(self.model.module.state_dict(), save_path)
        # assert(os.path.isfile(save_path))
        # # torch.save(self.model.module.state_dict(), checkpoint_name)
        # # assert(os.path.isfile(cur_path + '/' + checkpoint_name))
        # print('checkpoint saved ok')
        # # checkpoint saved

        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        # import pdb; pdb.set_trace()
        # label_w = label_stats(self.train_loader, nimg=70) #** -args.norm_loss if args.norm_loss != 0 else None
        # import pdb; pdb.set_trace()

        for i, sample in enumerate(tbar):
            #print("i is:{}, index is:{}".format(i,sample['index']))
            #print("path is:{}".format(sample['path']))
            #image, target = sample['image'], sample['label']
            image, target, index, path, b_mask, enlarged_b_mask = sample[
                'image'], sample['label'], sample['index'], sample[
                    'path'], sample['b_mask'], sample['enlarged_b_mask']
            #print('sample for training index is :{} and path is:{}'.format(index,path))
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # not using a learning rate scheduler; apply a fixed learning rate instead
            #self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            #import pdb; pdb.set_trace()
            #loss = self.criterion(output, target, b_mask, enlarged_b_mask)
            loss1, loss2, loss3, loss = self.criterion(output, target, b_mask,
                                                       enlarged_b_mask)
            # criterion = nn.BCELoss()
            # loss = criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            #import pdb; pdb.set_trace()
            tbar.set_description('Train loss_total: %.3f' % (train_loss /
                                                             (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss1: %.3f' % (train_loss1 / (i + 1)))
            self.writer.add_scalar('train/total_loss1_iter', loss1.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss2: %.3f' % (train_loss2 / (i + 1)))
            self.writer.add_scalar('train/total_loss2_iter', loss2.item(),
                                   i + num_img_tr * epoch)
            # tbar.set_description('Train loss3: %.3f' % (train_loss3 / (i + 1)))
            self.writer.add_scalar('train/total_loss3_iter', loss3.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # guard against epochs with fewer than 10 batches
                #if i % (num_img_tr // 10000) == 0:   #for the whole dataset
                #if i % (num_img_tr // 10) == 0:    # for debugging
                global_step = i + num_img_tr * epoch
                #self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output, b_mask,
                                             enlarged_b_mask, global_step)
            # save the model every 500 iterations
            if i % 500 == 0:  # for the whole dataset
                #if i % 5 == 0:    # for debugging
                is_best = False
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)

            # perform the validation every 300 iterations
            if i % 300 == 0:
                #if i % 1000 == 0 :  #for the whole dataset
                #if i % 15 == 0 :    # for debugging
                self.validation(i)
                #self.validation(i + num_img_tr * epoch)

            ## garbage collection pass
            #del image, target, index, path, b_mask, enlarged_b_mask, output
            #gc.collect()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.writer.add_scalar('train/total_loss1_epoch', train_loss1, epoch)
        self.writer.add_scalar('train/total_loss2_epoch', train_loss2, epoch)
        self.writer.add_scalar('train/total_loss3_epoch', train_loss3, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            #print('saving checkpoint')
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        test_loss1 = 0.0
        test_loss2 = 0.0
        test_loss3 = 0.0
        for i, sample in enumerate(tbar):
            #image, target = sample['image'], sample['label']
            #image, target, index, path = sample['image'], sample['label'], sample['index'], sample['path']
            image, target, index, path, b_mask, enlarged_b_mask = sample[
                'image'], sample['label'], sample['index'], sample[
                    'path'], sample['b_mask'], sample['enlarged_b_mask']
            #print('sample for testing index is :{} and path is:{}'.format(index,path))
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            #loss = self.criterion(output, target)
            #loss = self.criterion(output, target, b_mask, enlarged_b_mask)
            loss1, loss2, loss3, loss = self.criterion(output, target, b_mask,
                                                       enlarged_b_mask)
            test_loss += loss.item()
            test_loss1 += loss1.item()
            test_loss2 += loss2.item()
            test_loss3 += loss3.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            target = np.argmax(target, axis=1)
            #pred = np.argmax(pred, axis=1)
            #import pdb; pdb.set_trace()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/total_loss1_epoch', test_loss1, epoch)
        self.writer.add_scalar('val/total_loss2_epoch', test_loss2, epoch)
        self.writer.add_scalar('val/total_loss3_epoch', test_loss3, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
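
Every trainer in this listing calls Evaluator.Mean_Intersection_over_Union()
without showing it. For a confusion-matrix-based evaluator, per-class IoU is
TP / (TP + FP + FN); a minimal NumPy sketch (miou_from_confusion is a
hypothetical name; Example #26 below expects both the mean and the per-class
vector):

import numpy as np

def miou_from_confusion(cm):
    # cm: [K, K] confusion matrix, rows = ground truth, cols = prediction
    tp = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp  # TP + FP + FN per class
    iou = np.where(union > 0, tp / np.maximum(union, 1), np.nan)
    return np.nanmean(iou), iou  # mean IoU and the per-class vector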
Example #26
0
class Trainer(object):
    def __init__(self, args):
        warnings.filterwarnings('ignore')
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(
            args.backbone, args.dataset, args.exp)
        if args.dataset == 'pascal':
            raise NotImplementedError
        elif args.dataset == 'cityscapes':
            kwargs = {
                'num_workers': args.workers,
                'pin_memory': True,
                'drop_last': True
            }
            dataset_loader, num_classes = dataloaders.make_data_loader(
                args, **kwargs)
            args.num_classes = num_classes
        elif args.dataset == 'marsh':
            kwargs = {
                'num_workers': args.workers,
                'pin_memory': True,
                'drop_last': True
            }
            dataset_loader, val_loader, test_loader, num_classes = dataloaders.make_data_loader(
                args, **kwargs)
            args.num_classes = num_classes
        else:
            raise ValueError('Unknown dataset: {}'.format(args.dataset))

        if args.backbone == 'autodeeplab':
            model = Retrain_Autodeeplab(args)
        else:
            raise ValueError('Unknown backbone: {}'.format(args.backbone))

        if args.criterion == 'Ohem':
            args.thresh = 0.7
            args.crop_size = [args.crop_size, args.crop_size] if isinstance(
                args.crop_size, int) else args.crop_size
            args.n_min = int((args.batch_size / len(args.gpu) *
                              args.crop_size[0] * args.crop_size[1]) // 16)
        criterion = build_criterion(args)

        model = nn.DataParallel(model).cuda()
        model.train()
        if args.freeze_bn:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
        optimizer = optim.SGD(model.module.parameters(),
                              lr=args.base_lr,
                              momentum=0.9,
                              weight_decay=0.0001)

        max_iteration = len(dataset_loader) * args.epochs
        scheduler = Iter_LR_Scheduler(args, max_iteration, len(dataset_loader))

        start_epoch = 0

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))
                self.best_pred = checkpoint['best_pred']
            else:
                raise ValueError('=> no checkpoint found at {0}'.format(
                    args.resume))
        ## merged
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        #kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataset_loader, val_loader, test_loader, num_classes

        self.criterion = criterion
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = scheduler
        self.scheduler = LR_Scheduler(
            "poly", args.lr, args.epochs,
            len(self.train_loader))  # removed None from the second parameter

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            cur_iter = epoch * len(self.train_loader) + i
            # self.scheduler(self.optimizer, cur_iter)  # this scheduler did not work; trying the other one for ~500 epochs
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # guard against epochs with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                #print("I was here!!")
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        # save checkpoint every epoch
        is_best = False
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Classwise_IoU:")
        print(IoU)
        print('Loss: %.3f' % test_loss)
        print(self.evaluator.confusion_matrix)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
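
Example #26 sets args.thresh and args.n_min for the 'Ohem' criterion, but
build_criterion itself is not shown. A hedged sketch of the online hard
example mining cross-entropy it presumably constructs (OhemCELossSketch is a
hypothetical class modelled on the common BiSeNet-style implementation; it
assumes n_min is smaller than the number of scored pixels):

import torch
import torch.nn as nn
import torch.nn.functional as F

class OhemCELossSketch(nn.Module):
    def __init__(self, thresh=0.7, n_min=0, ignore_index=255):
        super().__init__()
        # convert the probability threshold into a per-pixel loss threshold
        self.thresh = -torch.log(torch.tensor(thresh)).item()
        self.n_min = n_min
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        loss = F.cross_entropy(logits, labels,
                               ignore_index=self.ignore_index,
                               reduction='none').view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]  # keep every sufficiently hard pixel
        else:
            loss = loss[:self.n_min]  # but always keep the n_min hardest
        return torch.mean(loss)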
Example #27
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # guard against epochs with fewer than 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
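
calculate_weigths_labels (sic) is referenced by several trainers in this
listing but never shown. Repositories in the pytorch-deeplab-xception lineage
typically weight classes by the inverse log of their pixel frequency; a
minimal sketch under that assumption (class_weights_sketch is a hypothetical
name):

import numpy as np

def class_weights_sketch(label_counts):
    # label_counts: 1-D array of per-class pixel counts over the training set
    freq = label_counts / label_counts.sum()
    return 1.0 / np.log(1.02 + freq)  # rarer classes receive larger weights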
Example #28
0
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                # this Trainer defines train_loader1/2; there is no self.train_loader
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader1,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass,
                            12,  # number of layers in the searched network
                            self.criterion,
                            crop_size=self.args.crop_size,
                            lambda_latency=self.args.lambda_latency)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        self.evaluator_device = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader1))
        self.architect = Architect(self.model, args)

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # both branches loaded the same state dict; a single call suffices
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def fetch_arch(self):
        d = dict()
        d['alphas_cell'] = self.model.arch_parameters()[0]
        d['alphas_network'] = self.model.arch_parameters()[1]
        d['alphas_distributed'] = self.model.arch_parameters()[2]
        return d

    def training(self, epoch):
        train_la, train_loss = 0.0, 0.0
        self.model.train()
        tbar = tqdm(self.train_loader1)
        tbar2 = tqdm(self.train_loader1)  # second bar only displays the loss breakdown
        num_img_tr = len(self.train_loader1)
        for i, sample in enumerate(tbar):

            image, target = sample['image'], sample['label']
            search = next(iter(self.train_loader2))
            image_search, target_search = search['image'], search['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output, device_output, loss, la, c_loss, d_loss = self.model._loss(
                image, target)
            if i % max(num_img_tr // 10, 1) == 0:  # guard against loaders with < 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(
                    self.writer, self.args.dataset, image, target,
                    output * 0.5 + device_output * 0.5, global_step)
            loss.backward()
            self.optimizer.step()
            # weights-only warm-up: update architecture params only after epoch 20
            if epoch >= 20:
                if self.args.cuda:
                    image_search = image_search.cuda()
                    target_search = target_search.cuda()
                self.architect.step(image_search, target_search)
            train_la += la
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f   Train latency: %.3f' %
                                 (train_loss / (i + 1), train_la / (i + 1)))
            tbar2.set_description(
                'cloud loss: %.3f   device loss: %.3f   latency loss: %.3f' %
                (c_loss, d_loss, la))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/cloud_loss_iter', c_loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/device_loss_iter', d_loss.item(),
                                   i + num_img_tr * epoch)
            self.writer.add_scalar('train/latency_loss_iter', la.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        self.writer.add_scalar('train/latency_loss_epoch', train_la, epoch)
        arch_board = self.model.arch_parameters()[0]
        for i in range(len(arch_board)):
            for j in range(len(arch_board[i])):
                self.writer.add_scalar('cell/' + str(i) + '/' + str(j),
                                       arch_board[i][j], epoch)

        arch_board = self.model.arch_parameters()[1]
        for i in range(len(arch_board)):
            for j in range(len(arch_board[i])):
                for k in range(len(arch_board[i][j])):
                    self.writer.add_scalar(
                        'network/' + str(i) + str(j) + str(k),
                        arch_board[i][j][k], epoch)

        arch_board = self.model.arch_parameters()[2]
        for i in range(len(arch_board)):
            self.writer.add_scalar('distributed/' + str(i), arch_board[i],
                                   epoch)

        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch_para': self.fetch_arch(),
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0

        self.evaluator_device.reset()

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output, device_output, loss, _, _, _ = self.model._loss(
                    image, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred_device = device_output.data.cpu().numpy()

            pred = np.argmax(pred, axis=1)
            pred_device = np.argmax(pred_device, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            self.evaluator_device.add_batch(target, pred_device)
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU_device = self.evaluator_device.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU_cloud', mIoU, epoch)
        self.writer.add_scalar('val/mIoU_device', mIoU_device, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
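A note on fetch_arch in the example above: it returns the continuous architecture parameters. After the search these are typically discretized with a softmax over the candidate operations followed by an argmax per edge. A minimal DARTS-style sketch follows; the operation names and tensor shape are assumptions, not part of the example:

import torch.nn.functional as F

# Hypothetical op list; the real AutoDeeplab search space differs.
OPS = ['none', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3']

def decode_cell(alphas_cell):
    # alphas_cell: [num_edges, num_ops] logits -> one op name per edge
    probs = F.softmax(alphas_cell, dim=-1)
    return [OPS[k] for k in probs.argmax(dim=-1).tolist()]

# usage: decode_cell(trainer.fetch_arch()['alphas_cell'])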
Example #29
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        # Define network
        model = PanopticDeepLab(
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
        )

        if args.create_params:
            train_params = [
                {
                    "params": model.get_1x_lr_params(),
                    "lr": args.lr
                },
                {
                    "params": model.get_10x_lr_params(),
                    "lr": args.lr * 10
                },
            ]
        else:
            train_params = model.parameters()

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy",
            )
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = PanopticLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        self.scheduler = ReduceLROnPlateau(optimizer,
                                           mode="max",
                                           factor=0.89,
                                           patience=2,
                                           verbose=True)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        semantic_loss_out = 0.0
        center_loss_out = 0.0
        reg_x_loss_out = 0.0
        reg_y_loss_out = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, label, center, x_reg, y_reg = (
                sample["image"],
                sample["label"],
                sample["center"],
                sample["x_reg"],
                sample["y_reg"],
            )
            if self.args.cuda:
                image, label, center, x_reg, y_reg = (
                    image.cuda(),
                    label.cuda(),
                    center.cuda(),
                    x_reg.cuda(),
                    y_reg.cuda(),
                )
            self.optimizer.zero_grad()
            try:
                output = self.model(image)
            except ValueError as err:
                # skip batches whose input size the network cannot handle
                print("Error:", err)
                continue
            (
                semantic_loss,
                center_loss,
                reg_x_loss,
                reg_y_loss,
            ) = self.criterion(output, label, center, x_reg, y_reg)

            # total loss
            loss = semantic_loss + center_loss + reg_x_loss + reg_y_loss

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            semantic_loss_out += semantic_loss.item()
            center_loss_out += center_loss.item()
            reg_x_loss_out += reg_x_loss.item()
            reg_y_loss_out += reg_y_loss.item()
            tbar.set_description(
                "Losses -> Train: %.3f, Semantic: %.3f, Center: %.3f, x_reg: %.3f, y_reg: %.3f"
                % (
                    train_loss / (i + 1),
                    semantic_loss_out / (i + 1),
                    center_loss_out / (i + 1),
                    reg_x_loss_out / (i + 1),
                    reg_y_loss_out / (i + 1),
                ))

            self.writer.add_scalar(
                "train/semantic_loss_iter",
                semantic_loss.item(),
                i + num_img_tr * epoch,
            )

            self.writer.add_scalar(
                "train/center_loss_iter",
                center_loss.item(),
                i + num_img_tr * epoch,
            )

            self.writer.add_scalar(
                "train/reg_x_loss_iter",
                reg_x_loss.item(),
                i + num_img_tr * epoch,
            )

            self.writer.add_scalar(
                "train/reg_y_loss_iter",
                reg_y_loss.item(),
                i + num_img_tr * epoch,
            )

            self.writer.add_scalar("train/total_loss_iter", loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % max(num_img_tr // 10, 1) == 0:  # guard against loaders with < 10 batches
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(
                    self.writer,
                    self.args.dataset,
                    image,
                    label,
                    output[0],
                    global_step,
                    centers=output[1],
                    reg_x=output[2],
                    reg_y=output[3],
                )

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Loss: %.3f" % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, label, center, x_reg, y_reg = (
                sample["image"],
                sample["label"],
                sample["center"],
                sample["x_reg"],
                sample["y_reg"],
            )
            if self.args.cuda:
                image, label, center, x_reg, y_reg = (
                    image.cuda(),
                    label.cuda(),
                    center.cuda(),
                    x_reg.cuda(),
                    y_reg.cuda(),
                )
            with torch.no_grad():
                try:
                    output = self.model(image)
                except ValueError as err:
                    # skip inputs whose size the network cannot handle
                    print("Error:", err)
                    continue

            (
                semantic_loss,
                center_loss,
                reg_x_loss,
                reg_y_loss,
            ) = self.criterion(output, label, center, x_reg, y_reg)

            # total loss
            loss = semantic_loss + center_loss + reg_x_loss + reg_y_loss
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output[0].data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print("[Epoch: %d, numImages: %5d]" %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: %.3f" % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )
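The criterion in the example above returns four loss terms that the trainer simply sums. Below is a minimal sketch of what such a multi-head criterion might look like, using cross-entropy for the semantic head, MSE for the center heatmap, and L1 for the offsets, which are the choices made in the Panoptic-DeepLab paper; the actual PanopticLosses class is not shown in this example and may weight or reduce the terms differently:

import torch.nn as nn

class SimplePanopticCriterion(nn.Module):
    # Illustrative stand-in for PanopticLosses.build_loss(); the loss
    # choices and reductions here are assumptions.
    def __init__(self, ignore_index=255):
        super().__init__()
        self.semantic = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.center = nn.MSELoss()   # center heatmap regression
        self.offset = nn.L1Loss()    # x/y offset regression

    def forward(self, output, label, center, x_reg, y_reg):
        sem_pred, center_pred, x_pred, y_pred = output
        return (self.semantic(sem_pred, label.long()),
                self.center(center_pred, center),
                self.offset(x_pred, x_reg),
                self.offset(y_pred, y_reg))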
Example #30
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size
        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(6000, 6000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        else:
            raise NotImplementedError

        if args.model == 'FCN':
            self.model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            self.model = models.DeepLab(num_classes=self.num_of_class,
                                        backbone='resnet')
        elif args.model == 'GCN':
            self.model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            self.model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            self.model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            self.model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError

        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError

        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode == 'miou' or self.schedule_mode == 'acc':
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError

        self.evaluator = Evaluator(self.num_of_class)

        self.model = nn.DataParallel(self.model)

        self.cuda = args.cuda
        if self.cuda:
            self.model = self.model.cuda()

        self.resume = args.resume
        self.finetune = args.finetune
        # resume and finetune are mutually exclusive
        assert not (self.resume is not None and self.finetune is not None)

        if self.resume is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.resume)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.start_epoch = checkpoint['epoch'] + 1
            #start from next epoch
        elif self.finetune is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.finetune)
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1
        if self.mode == 'train':
            # self.model was wrapped in nn.DataParallel above, so use .module
            # to log the underlying network's name instead of 'DataParallel'
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        self.model.module.__class__.__name__ +
                                        '_' + args.loss)
        self.init_eval = args.init_eval

    #Note: self.start_epoch and self.epochs are only used in run() to schedule training & validation
    def run(self):
        if self.init_eval:  #init with an evaluation
            init_test_epoch = self.start_epoch - 1
            Acc, _, mIoU, _ = self.validate(init_test_epoch, save=True)
            self.writer.add_scalar('eval/Acc', Acc, init_test_epoch)
            self.writer.add_scalar('eval/mIoU', mIoU, init_test_epoch)
            self.writer.flush()
        end_epoch = self.start_epoch + self.epochs
        for epoch in range(self.start_epoch, end_epoch):
            loss = self.train(epoch)
            self.writer.add_scalar(
                'train/lr',
                self.optimizer.state_dict()['param_groups'][0]['lr'], epoch)
            self.writer.add_scalar('train/loss', loss, epoch)
            self.writer.flush()
            saved_dict = {
                'model': self.model.module.__class__.__name__,
                'epoch': epoch,
                'dataset': self.dataset,
                'parameters': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }
            torch.save(
                saved_dict,
                f'./{self.model.module.__class__.__name__}_{self.dataset}_epoch{epoch}.pth.tar'
            )

            Acc, _, mIoU, _ = self.validate(epoch, save=True)
            self.writer.add_scalar('eval/Acc', Acc, epoch)
            self.writer.add_scalar('eval/mIoU', mIoU, epoch)
            self.writer.flush()
            if self.schedule_mode == 'step' or self.schedule_mode == 'poly':
                self.scheduler.step()
            elif self.schedule_mode == 'miou':
                self.scheduler.step(mIoU)
            elif self.schedule_mode == 'acc':
                self.scheduler.step(Acc)
            else:
                raise NotImplementedError
        self.writer.close()

    def train(self, epoch):
        self.model.train()
        print(f"----------epoch {epoch}----------")
        print("lr:", self.optimizer.state_dict()['param_groups'][0]['lr'])
        total_loss = 0
        num_of_batches = len(self.train_loader) * self.epoch_repeat
        # repeat the loader self.epoch_repeat times so num_of_batches above
        # is accurate (the hard-coded range(100) disagreed with it)
        for itr in range(self.epoch_repeat):
            for i, [img, gt] in enumerate(self.train_loader):
                print(
                    f"epoch: {epoch} batch: {i+1+itr*len(self.train_loader)}/{num_of_batches}"
                )
                print("img:", img.shape)
                print("gt:", gt.shape)
                self.optimizer.zero_grad()
                if self.cuda:
                    img, gt = img.cuda(), gt.cuda()
                pred = self.model(img)
                print("pred:", pred.shape)
                loss = self.criterion(pred, gt.long())
                print("loss:", loss)
                total_loss += loss.item()  # .item() accumulates a plain float, not a tensor
                loss.backward()
                self.optimizer.step()
        return total_loss

    def validate(self, epoch, save):
        self.model.eval()
        print(f"----------validate epoch {epoch}----------")
        # the existence check and mkdir used different names ("epoch_" vs
        # "epoch"); use the one the PNGs are saved under below
        if save and not os.path.exists("epoch" + str(epoch)):
            os.mkdir("epoch" + str(epoch))
        num_of_imgs = len(self.eval_loader)
        for i, sample in enumerate(self.eval_loader):
            img_name, gt_name = sample['img'][0], sample['gt'][0]
            print(f"{i+1}/{num_of_imgs}:")

            img = Image.open(img_name).convert('RGB')
            gt = np.array(Image.open(gt_name))
            times, points = self.get_pointset(img)
            print(f'{times} tests will be carried out on {img_name}...')
            W, H = img.size  # PIL.Image.size is (W, H); the numpy maps below are [H, W]
            label_map = np.zeros([H, W], dtype=np.uint8)
            score_map = np.zeros([H, W], dtype=np.uint8)
            # score_map need not be uint8, but uint8 happened to give better results
            tbar = tqdm(points)
            # x, y rather than i, j, so the loop does not shadow the image index i
            for x, y in tbar:
                tbar.set_description(f"{x},{y}")
                label_map, score_map = self.test_patch(x, y, img, label_map,
                                                       score_map)
            # finished stitching one large image; score it against the ground truth
            # (add_batch takes the ground truth first, as in the examples above)
            self.evaluator.add_batch(gt, label_map)
            if save:
                mask = ret2mask(label_map, dataset=self.dataset)
                png_name = os.path.join(
                    "epoch" + str(epoch),
                    os.path.basename(img_name).split('.')[0] + '.png')
                Image.fromarray(mask).save(png_name)
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:", Acc)
        print("Acc_class:", Acc_class)
        print("mIoU:", mIoU)
        print("FWIoU:", FWIoU)
        self.evaluator.reset()
        return Acc, Acc_class, mIoU, FWIoU

    def test_patch(self, i, j, img, label_map, score_map):
        tr = EvaluationTransform(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        #print(img.size)
        cropped = img.crop(
            (i, j, i + self.eval_crop_size, j + self.eval_crop_size))
        cropped = tr(cropped).unsqueeze(0)
        if self.cuda:
            cropped = cropped.cuda()
        with torch.no_grad():  # inference only; do not build the autograd graph
            out = self.model(cropped)
        #out = torch.nn.functional.softmax(out, dim=1)
        ret = torch.max(out.squeeze(), dim=0)
        score = ret[0].data.detach().cpu().numpy()
        label = ret[1].data.detach().cpu().numpy()

        #numpy array's shape is [H,W] while PIL.Image is [W,H]
        score_temp = score_map[j:j + self.eval_crop_size,
                               i:i + self.eval_crop_size]
        label_temp = label_map[j:j + self.eval_crop_size,
                               i:i + self.eval_crop_size]
        index = score > score_temp
        score_temp[index] = score[index]
        label_temp[index] = label[index]
        label_map[j:j + self.eval_crop_size,
                  i:i + self.eval_crop_size] = label_temp
        score_map[j:j + self.eval_crop_size,
                  i:i + self.eval_crop_size] = score_temp

        return label_map, score_map

    def get_pointset(self, img):
        W, H = img.size
        pointset = []
        count = 0
        i = 0
        while i < W:
            break_flag_i = False
            if i + self.eval_crop_size >= W:
                i = W - self.eval_crop_size
                break_flag_i = True
            j = 0
            while j < H:
                break_flag_j = False
                if j + self.eval_crop_size >= H:
                    j = H - self.eval_crop_size
                    break_flag_j = True
                count += 1
                pointset.append((i, j))
                if break_flag_j:
                    break
                j += self.stride
            if break_flag_i:
                break
            i += self.stride
        value = get_test_times(W, H, self.eval_crop_size, self.stride)
        assert count == value, f'count={count} while get_test_times returns {value}'
        return count, pointset
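For reference, get_pointset above slides an eval_crop_size window in steps of stride and clamps the final position to the image border, so each axis contributes ceil((dim - crop) / stride) + 1 positions. Below is a sketch of a get_test_times consistent with that traversal; the helper actually imported by this example may be implemented differently:

import math

def get_test_times(W, H, crop, stride):
    # number of window positions produced by the sliding scheme in
    # get_pointset: full strides until the window would cross the border,
    # plus one final position clamped to the border
    def axis_count(dim):
        return 1 if dim <= crop else math.ceil((dim - crop) / stride) + 1
    return axis_count(W) * axis_count(H)

# e.g. W=4000, H=3000, crop=512, stride=256:
# axis_count(4000) = ceil(3488/256) + 1 = 15; axis_count(3000) = ceil(2488/256) + 1 = 11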