def __init__(self, args):
        """Build the single-backbone tester: test loader, network and weights.

        Args:
            args: parsed command-line options. Reads ``workers``, ``backbone``
                ('unet' or 'unetNested'), ``cuda`` and ``checkpoint_file``;
                the whole object is also forwarded to ``make_data_loader``.

        Raises:
            ValueError: if ``args.backbone`` is not a supported backbone.
            RuntimeError: if ``args.checkpoint_file`` does not exist.
        """
        self.args = args

        # Define Dataloader (only the test split and class count are kept)
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, _, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if self.args.backbone == 'unet':
            self.model = UNet(in_channels=4, n_classes=self.nclass)
            print("using UNet")
        elif self.args.backbone == 'unetNested':
            self.model = UNetNested(in_channels=4, n_classes=self.nclass)
            print("using UNetNested")
        else:
            # Fail fast: otherwise self.model stays None and the code below
            # dies with an unhelpful AttributeError.
            raise ValueError("unsupported backbone '{}'".format(
                self.args.backbone))

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        if not os.path.isfile(args.checkpoint_file):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.checkpoint_file))
        # Remap GPU-saved tensors to CPU when CUDA is disabled; without this
        # torch.load fails on machines that have no GPU.
        checkpoint = torch.load(args.checkpoint_file,
                                map_location=None if args.cuda else 'cpu')

        self.model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.checkpoint_file))
    def __init__(self, args):
        """Build the UNet / UNetNested / CombineNet ensemble and load weights.

        Args:
            args: parsed options; reads ``cuda``, ``unet_checkpoint_file``,
                ``unetNested_checkpoint_file`` and
                ``combine_net_checkpoint_file``.

        Raises:
            RuntimeError: if any of the three checkpoint files is missing.
        """
        self.args = args

        self.nclass = 16
        # Define network.
        # NOTE(review): CombineNet's 192 input channels presumably equal the
        # concatenated backbone outputs fed to it elsewhere — confirm.
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4,
                                           n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192,
                                            n_classes=self.nclass)

        # Using cuda
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Remap GPU-saved tensors to CPU when CUDA is disabled so the
        # checkpoints can still be loaded on a CPU-only machine.
        map_location = None if args.cuda else 'cpu'

        # Load Unet model
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no unet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file,
                                map_location=map_location)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded unet checkpoint '{}'".format(
            args.unet_checkpoint_file))

        # Load UNetNested model
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError(
                "=> no UNetNested checkpoint found at '{}'".format(
                    args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file,
                                map_location=map_location)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Load Combine Net
        if not os.path.isfile(args.combine_net_checkpoint_file):
            raise RuntimeError(
                "=> no combine net checkpoint found at '{}'".format(
                    args.combine_net_checkpoint_file))
        checkpoint = torch.load(args.combine_net_checkpoint_file,
                                map_location=map_location)
        self.combine_net_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded combine net checkpoint '{}'".format(
            args.combine_net_checkpoint_file))
    def __init__(self, args):
        """Set up training of a single segmentation backbone.

        Builds the saver and TensorBoard writer, the train/val/test loaders,
        the network selected by ``args.backbone``, an AMSGrad Adam optimizer,
        the loss and evaluator, and optionally resumes from ``args.resume``.

        Args:
            args: parsed options (backbone, sync_bn, workers, learn_rate,
                weight_decay, use_balanced_weights, dataset, loss_type, cuda,
                gpu_ids, resume, ft, ...). ``args.start_epoch`` is written
                when resuming or fine-tuning.

        Raises:
            ValueError: if ``args.backbone`` is neither 'unet' nor
                'unetNested'.
            RuntimeError: if ``args.resume`` points to a missing file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        if self.args.backbone == 'unet':
            model = UNet(in_channels=4,
                         n_classes=self.nclass,
                         sync_bn=args.sync_bn)
            print("using UNet")
        elif self.args.backbone == 'unetNested':
            model = UNetNested(in_channels=4,
                               n_classes=self.nclass,
                               sync_bn=args.sync_bn)
            print("using UNetNested")
        else:
            # Fail fast: with an unknown backbone ``model`` would stay None
            # and crash below on ``model.get_params()``.
            raise ValueError("unsupported backbone '{}'".format(
                self.args.backbone))

        train_params = [{'params': model.get_params()}]

        # Define Optimizer (Adam with AMSGrad; lr is args.learn_rate)
        optimizer = torch.optim.Adam(train_params,
                                     self.args.learn_rate,
                                     weight_decay=args.weight_decay,
                                     amsgrad=True)

        # Define Criterion
        # Optional class-balanced weights: read
        # <db_root>/<dataset>_classes_weights.npy if it exists, otherwise
        # compute them from the training loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda: wrap in DataParallel over args.gpu_ids
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # Remap GPU-saved tensors to CPU when CUDA is disabled.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            # Under CUDA the model is DataParallel-wrapped, so load into the
            # inner module.
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
    def __init__(self, args):
        """Set up CombineNet training on top of pretrained UNet / UNetNested.

        Args:
            args: parsed options; reads learn_rate, weight_decay, loss_type,
                cuda, unet_checkpoint_file, unetNested_checkpoint_file,
                resume and ft (plus whatever Saver reads).
                ``args.start_epoch`` is written when resuming or fine-tuning.

        Raises:
            RuntimeError: if a backbone checkpoint or ``args.resume`` is
                missing.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.nclass = 16
        # Define network
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4,
                                           n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192,
                                            n_classes=self.nclass)

        # Only CombineNet's parameters are optimised.
        train_params = [{'params': self.combine_net_model.get_params()}]
        # Define Optimizer
        self.optimizer = torch.optim.Adam(train_params,
                                          self.args.learn_rate,
                                          weight_decay=args.weight_decay,
                                          amsgrad=True)

        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda (models are moved, NOT wrapped in DataParallel)
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Remap GPU-saved tensors to CPU when CUDA is disabled.
        map_location = None if args.cuda else 'cpu'

        # Load Unet checkpoint
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no Unet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file,
                                map_location=map_location)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded Unet checkpoint '{}'".format(
            args.unet_checkpoint_file))

        # Load UNetNested checkpoint
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError(
                "=> no UNetNested checkpoint found at '{}'".format(
                    args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file,
                                map_location=map_location)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Resuming combineNet checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no combineNet checkpoint found at '{}'".format(
                        args.resume))
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            # combine_net_model is never wrapped in DataParallel here, so it
            # has no ``.module``; the old CUDA branch raised AttributeError.
            self.combine_net_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded combineNet checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
class Trainer(object):
    """Train a CombineNet that fuses test-time-augmented predictions of a
    pretrained UNet and a pretrained UNetNested.

    Each backbone contributes 6 augmented views x ``nclass`` (16) channels,
    so CombineNet receives 2 * 6 * 16 = 192 input channels.
    """

    def __init__(self, args):
        """Build networks, optimizer, loss and evaluator; load the backbones.

        Args:
            args: parsed options; reads learn_rate, weight_decay, loss_type,
                cuda, unet_checkpoint_file, unetNested_checkpoint_file,
                resume and ft (plus whatever Saver reads).
                ``args.start_epoch`` is written when resuming or fine-tuning.

        Raises:
            RuntimeError: if a backbone checkpoint or ``args.resume`` is
                missing.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.nclass = 16
        # Define network
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4,
                                           n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192,
                                            n_classes=self.nclass)

        # Only CombineNet is optimised; the backbones are run under no_grad.
        train_params = [{'params': self.combine_net_model.get_params()}]
        # Define Optimizer
        self.optimizer = torch.optim.Adam(train_params,
                                          self.args.learn_rate,
                                          weight_decay=args.weight_decay,
                                          amsgrad=True)

        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda (models are moved, NOT wrapped in DataParallel)
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Remap GPU-saved tensors to CPU when CUDA is disabled.
        map_location = None if args.cuda else 'cpu'

        # Load Unet checkpoint
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no Unet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file,
                                map_location=map_location)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded Unet checkpoint '{}'".format(
            args.unet_checkpoint_file))

        # Load UNetNested checkpoint
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError(
                "=> no UNetNested checkpoint found at '{}'".format(
                    args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file,
                                map_location=map_location)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Resuming combineNet checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no combineNet checkpoint found at '{}'".format(
                        args.resume))
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            # combine_net_model is never DataParallel-wrapped and validation()
            # saves its plain state_dict, so load directly (the previous
            # ``.module`` access raised AttributeError under CUDA).
            self.combine_net_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded combineNet checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _load_sample(self, image_dir, label_dir, filename):
        """Open an image and its ``*_labelTrainIds.png`` label.

        Returns:
            (PIL image, float label tensor of shape (1, 400, 400)); the label
            is moved to CUDA only when ``args.cuda`` is set.
        """
        image = Image.open(os.path.join(image_dir, filename))
        label = Image.open(
            os.path.join(
                label_dir,
                os.path.basename(filename)[:-4] + '_labelTrainIds.png'))
        label = np.array(label).astype(np.float32)
        label = label.reshape((1, 400, 400))
        label = torch.from_numpy(label).float()
        if self.args.cuda:
            label = label.cuda()
        return image, label

    def _combine_input(self, image):
        """Concatenate both backbones' TTA predictions along channels."""
        unet_pred = self.unet_multi_scale_predict(image)
        unetnested_pred = self.unetnested_multi_scale_predict(image)
        return torch.cat([unet_pred, unetnested_pred], 1)

    def training(self, epoch):
        """Train CombineNet for one epoch over ``train_files`` (one image per
        step) and log per-iteration loss plus epoch metrics."""
        print('[Epoch: %d, previous best = %.4f]' % (epoch, self.best_pred))
        train_loss = 0.0
        self.combine_net_model.train()
        self.evaluator.reset()
        num_img_tr = len(train_files)
        tbar = tqdm(train_files, desc='\r')

        for i, filename in enumerate(tbar):
            image, label = self._load_sample(train_dir, train_label_dir,
                                             filename)
            net_input = self._combine_input(image)

            self.optimizer.zero_grad()
            output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Accumulate metrics for every sample. This used to run once after
            # the loop, so only the last image was ever evaluated.
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            self.evaluator.add_batch(label.cpu().numpy(), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('train/mIoU', mIoU, epoch)
        self.writer.add_scalar('train/Acc', Acc, epoch)
        self.writer.add_scalar('train/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('train/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)

        print('train validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % train_loss)
        print('---------------------------------')

    def validation(self, epoch):
        """Evaluate CombineNet on ``val_files``, log loss/metrics, and save a
        checkpoint (flagged best) whenever validation mIoU improves."""
        test_loss = 0.0
        self.combine_net_model.eval()
        self.evaluator.reset()
        tbar = tqdm(val_files, desc='\r')
        num_img_val = len(val_files)

        for i, filename in enumerate(tbar):
            image, label = self._load_sample(val_dir, val_label_dir, filename)
            net_input = self._combine_input(image)

            with torch.no_grad():
                output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1)))
            self.writer.add_scalar('val/total_loss_iter', loss.item(),
                                   i + num_img_val * epoch)
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label.cpu().numpy(), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('test validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        print('====================================')

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.combine_net_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    @staticmethod
    def _to_hwc(output):
        """First sample of an (N, C, H, W) tensor as an (H, W, C) ndarray."""
        return output.data.cpu().numpy()[0].transpose((1, 2, 0))

    @staticmethod
    def _to_nchw(pred, device):
        """(H, W, C) ndarray -> (1, C, H, W) float tensor on ``device``."""
        return torch.from_numpy(np.array([
            pred.transpose((2, 0, 1)),
        ])).float().to(device)

    def _multi_scale_predict(self, image_ori, predict_fn):
        """Test-time augmentation around ``predict_fn``.

        Predicts on the original image, on copies rotated by 90/180/270
        degrees and on vertical/horizontal flips; each prediction is mapped
        back to the original frame, and all six are concatenated along dim 1
        in that order. Results stay on the device ``predict_fn`` returns
        (previously they were forced to CUDA even with ``args.cuda`` off).
        Assumes 400x400 inputs (rotation centre and size are hard-coded).
        """
        outputs = [predict_fn(image_ori.copy())]
        device = outputs[0].device

        # PIL rotates counter-clockwise by `angle`; rotating the prediction by
        # 360 - angle undoes it.
        # NOTE(review): (200, 200) is not the pixel-grid centre (199.5, 199.5)
        # of a 400x400 image, so un-rotated maps are off by one pixel with a
        # zero-filled border row/column. Kept as-is so features match
        # CombineNet checkpoints trained on them.
        for angle in (90, 180, 270):
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            pred = self._to_hwc(predict_fn(img_rotate))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            outputs.append(self._to_nchw(pred, device))

        # cv2.flip code 0 = vertical flip, 1 = horizontal flip; each undoes
        # the matching PIL transpose.
        for pil_flip, cv2_code in ((Image.FLIP_TOP_BOTTOM, 0),
                                   (Image.FLIP_LEFT_RIGHT, 1)):
            pred = self._to_hwc(predict_fn(image_ori.transpose(pil_flip)))
            pred = cv2.flip(pred, cv2_code)
            outputs.append(self._to_nchw(pred, device))

        return torch.cat(outputs, 1)

    def unet_multi_scale_predict(self, image_ori: 'Image.Image'):
        """TTA prediction with the UNet backbone (6 views concatenated)."""
        self.unet_model.eval()
        return self._multi_scale_predict(image_ori, self.unet_predict)

    def unetnested_multi_scale_predict(self, image_ori: 'Image.Image'):
        """TTA prediction with the UNetNested backbone (6 views concatenated)."""
        self.unetNested_model.eval()
        return self._multi_scale_predict(image_ori, self.unetnested_predict)

    def _predict(self, model, img):
        """Normalise ``img`` and run ``model`` on it without gradients."""
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = model(img)
        return output

    def unet_predict(self, img: 'Image.Image') -> torch.Tensor:
        """Single forward pass of the UNet backbone."""
        return self._predict(self.unet_model, img)

    def unetnested_predict(self, img: 'Image.Image') -> torch.Tensor:
        """Single forward pass of the UNetNested backbone."""
        return self._predict(self.unetNested_model, img)

    @staticmethod
    def transform_test(img):
        """Normalise a 4-channel H x W x 4 image into a (1, 4, H, W) tensor.

        Pixels are scaled to [0, 1] and standardised with the per-channel
        mean/std constants below.
        """
        # Normalize
        mean = (0.544650, 0.352033, 0.384602, 0.352311)
        std = (0.249456, 0.241652, 0.228824, 0.227583)
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # ToTensor: HWC -> CHW, then add a batch dimension
        img = img.transpose((2, 0, 1))
        img = np.array([
            img,
        ])
        img = torch.from_numpy(img).float()
        return img
class Visualization:
    """Run the UNet + UNetNested + CombineNet ensemble on ``test_files`` and
    save raw class-id maps and colourised maps under ``args.vis_logdir``."""

    def __init__(self, args):
        """Build the three networks and load their checkpoints.

        Args:
            args: parsed options; reads cuda, unet_checkpoint_file,
                unetNested_checkpoint_file and combine_net_checkpoint_file
                (``dataset`` and ``vis_logdir`` are read later).

        Raises:
            RuntimeError: if any of the three checkpoint files is missing.
        """
        self.args = args

        self.nclass = 16
        # Define network: 2 backbones x 6 TTA views x 16 classes = 192
        # input channels for CombineNet.
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4,
                                           n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192,
                                            n_classes=self.nclass)

        # Using cuda
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Remap GPU-saved tensors to CPU when CUDA is disabled.
        map_location = None if args.cuda else 'cpu'

        # Load Unet model
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no unet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file,
                                map_location=map_location)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded unet checkpoint '{}'".format(
            args.unet_checkpoint_file))

        # Load UNetNested model
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError(
                "=> no UNetNested checkpoint found at '{}'".format(
                    args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file,
                                map_location=map_location)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Load Combine Net
        if not os.path.isfile(args.combine_net_checkpoint_file):
            raise RuntimeError(
                "=> no combine net checkpoint found at '{}'".format(
                    args.combine_net_checkpoint_file))
        checkpoint = torch.load(args.combine_net_checkpoint_file,
                                map_location=map_location)
        self.combine_net_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded combine net checkpoint '{}'".format(
            args.combine_net_checkpoint_file))

    def visualization(self):
        """Predict each file in ``test_files`` and write, under the same file
        name, a class-id map to ``vis_logdir/raw_train_id`` and a colour map
        to ``vis_logdir/vis_color`` (directories are created if missing)."""
        self.combine_net_model.eval()
        raw_dir = os.path.join(self.args.vis_logdir, 'raw_train_id')
        color_dir = os.path.join(self.args.vis_logdir, 'vis_color')
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(color_dir, exist_ok=True)
        tbar = tqdm(test_files, desc='\r')

        for filename in tbar:
            image = Image.open(os.path.join(test_dir, filename))

            net_input = torch.cat([
                self.unet_multi_scale_predict(image),
                self.unetnested_multi_scale_predict(image),
            ], 1)

            with torch.no_grad():
                output = self.combine_net_model(net_input)
            pred = output.data.cpu().numpy()[0]
            pred = np.argmax(pred, axis=0)

            rgb = decode_segmap(pred, self.args.dataset)

            # argmax yields a wide integer array; handing it to Pillow as an
            # 8-bit 'L' image read 1 byte per pixel from that buffer and
            # produced garbage. Class ids (< 16) fit in uint8.
            pred_img = Image.fromarray(pred.astype(np.uint8))
            # NOTE(review): decode_segmap's output dtype/range is not visible
            # here; mode='RGB' needs a uint8 array in [0, 255] — confirm.
            rgb_img = Image.fromarray(rgb, mode='RGB')

            pred_img.save(os.path.join(raw_dir, filename))
            rgb_img.save(os.path.join(color_dir, filename))

    @staticmethod
    def _to_hwc(output):
        """First sample of an (N, C, H, W) tensor as an (H, W, C) ndarray."""
        return output.data.cpu().numpy()[0].transpose((1, 2, 0))

    @staticmethod
    def _to_nchw(pred, device):
        """(H, W, C) ndarray -> (1, C, H, W) float tensor on ``device``."""
        return torch.from_numpy(np.array([
            pred.transpose((2, 0, 1)),
        ])).float().to(device)

    def _multi_scale_predict(self, image_ori, predict_fn):
        """Test-time augmentation around ``predict_fn``.

        Predicts on the original image, on copies rotated by 90/180/270
        degrees and on vertical/horizontal flips; each prediction is mapped
        back to the original frame, and all six are concatenated along dim 1
        in that order. Results stay on the device ``predict_fn`` returns
        (previously they were forced to CUDA even with ``args.cuda`` off).
        Assumes 400x400 inputs (rotation centre and size are hard-coded).
        """
        outputs = [predict_fn(image_ori.copy())]
        device = outputs[0].device

        # PIL rotates counter-clockwise by `angle`; rotating the prediction by
        # 360 - angle undoes it.
        # NOTE(review): (200, 200) is not the pixel-grid centre (199.5, 199.5)
        # of a 400x400 image, so un-rotated maps are off by one pixel with a
        # zero-filled border row/column. Kept as-is so features match the
        # CombineNet checkpoints trained on them.
        for angle in (90, 180, 270):
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            pred = self._to_hwc(predict_fn(img_rotate))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            outputs.append(self._to_nchw(pred, device))

        # cv2.flip code 0 = vertical flip, 1 = horizontal flip; each undoes
        # the matching PIL transpose.
        for pil_flip, cv2_code in ((Image.FLIP_TOP_BOTTOM, 0),
                                   (Image.FLIP_LEFT_RIGHT, 1)):
            pred = self._to_hwc(predict_fn(image_ori.transpose(pil_flip)))
            pred = cv2.flip(pred, cv2_code)
            outputs.append(self._to_nchw(pred, device))

        return torch.cat(outputs, 1)

    def unet_multi_scale_predict(self, image_ori: 'Image.Image'):
        """TTA prediction with the UNet backbone (6 views concatenated)."""
        self.unet_model.eval()
        return self._multi_scale_predict(image_ori, self.unet_predict)

    def _predict(self, model, img):
        """Normalise ``img`` and run ``model`` on it without gradients."""
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = model(img)
        return output

    def unet_predict(self, img: 'Image.Image') -> torch.Tensor:
        """Single forward pass of the UNet backbone."""
        return self._predict(self.unet_model, img)

    def unetnested_predict(self, img: 'Image.Image') -> torch.Tensor:
        """Single forward pass of the UNetNested backbone."""
        return self._predict(self.unetNested_model, img)

    def unetnested_multi_scale_predict(self, image_ori: 'Image.Image'):
        """TTA prediction with the UNetNested backbone (6 views concatenated)."""
        self.unetNested_model.eval()
        return self._multi_scale_predict(image_ori, self.unetnested_predict)

    @staticmethod
    def transform_test(img):
        """Normalise a 4-channel H x W x 4 image into a (1, 4, H, W) tensor.

        Pixels are scaled to [0, 1] and standardised with the per-channel
        mean/std constants below.
        """
        # Normalize
        mean = (0.544650, 0.352033, 0.384602, 0.352311)
        std = (0.249456, 0.241652, 0.228824, 0.227583)
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # ToTensor: HWC -> CHW, then add a batch dimension
        img = img.transpose((2, 0, 1))
        img = np.array([
            img,
        ])
        img = torch.from_numpy(img).float()
        return img
"""
@function: 读取单独保存的模型参数,将其与模型结构一起重新保存
           (load separately saved model parameters and re-save them
           together with the model structure)
@author:HuiYi or 会意
@file: vis.py.py
@time: 2019/7/30 下午7:00
"""
import torch
from models.backbone.UNet import UNet

# Candidate UNet checkpoints (one per experiment run).
model_path_list = [
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar',
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/checkpoint.pth.tar',
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_2/checkpoint.pth.tar'
]

if __name__ == '__main__':
    # Rebuild the UNet, restore the experiment_0 weights on the GPU, then
    # pickle the whole module (structure + parameters) next to the checkpoint.
    net = UNet(in_channels=4, n_classes=16, sync_bn=False).cuda()
    src_checkpoint = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar'
    net.load_state_dict(torch.load(src_checkpoint)['state_dict'])

    dst_file = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/model_and_param.pth.tar'
    torch.save(net, dst_file)
    print('save finish')

    # Reloading a re-saved model:
    # model = torch.load('/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/model_and_param.pth.tar')
    # params = model.state_dict()
    # print('load')