def __init__(self):
        """Set up the DeepLab segmentation model and the ROS image pub/sub pair.

        Loads ``./src/deeplab_ros/data/model_best.pth.tar`` into a 4-class
        MobileNet DeepLab placed on ``cuda:0`` when CUDA is available (CPU
        otherwise), subscribes to ``/cam2/pylon_camera_node/image_raw`` with
        ``self.callback`` and advertises ``segmentation_image``.
        """

        # GPU assignment: expose only physical GPU 0 to this process.
        # NOTE(review): only effective if CUDA has not been initialised yet in
        # this process — confirm nothing touches CUDA before this point.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint. map_location lets a checkpoint that was saved from a
        # GPU be loaded on a CPU-only machine instead of failing in torch.load.
        self.checkpoint = torch.load(
            "./src/deeplab_ros/data/model_best.pth.tar",
            map_location=self.device)

        # Load Model
        self.model = DeepLab(num_classes=4,
                             backbone='mobilenet',
                             output_stride=16,
                             sync_bn=True,
                             freeze_bn=False)

        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model = self.model.to(self.device)
        # NOTE(review): model.eval() is not called here; confirm the callback
        # switches to eval mode before inference (BatchNorm/Dropout behaviour).

        # ROS init: queue_size=1 keeps only the newest frame; buff_size is
        # raised to 16 MiB so a full raw camera image fits in one buffer.
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber("/cam2/pylon_camera_node/image_raw",
                                          ImageMsg,
                                          self.callback,
                                          queue_size=1,
                                          buff_size=2**24)
        self.image_pub = rospy.Publisher("segmentation_image",
                                         ImageMsg,
                                         queue_size=1)
Пример #2
0
class RAN():
    """Inference wrapper around a 2-class MobileNet DeepLab on a single GPU."""

    def __init__(self, weight, gpu_ids):
        """Build the model on GPU ``gpu_ids`` and load the checkpoint ``weight``.

        Args:
            weight: path to a checkpoint file holding a ``'state_dict'`` entry.
            gpu_ids: CUDA device index passed to ``torch.cuda.set_device``.

        Raises:
            ValueError: if ``weight`` is None.
            RuntimeError: if ``weight`` is not an existing file.
        """
        self.model = DeepLab(num_classes=2,
                             backbone='mobilenet',
                             output_stride=16)

        torch.cuda.set_device(gpu_ids)
        self.model = self.model.cuda()

        # Explicit check rather than `assert`, which is stripped under -O.
        if weight is None:
            raise ValueError("=> no checkpoint path given (weight is None)")
        if not os.path.isfile(weight):
            raise RuntimeError("=> no checkpoint found at '{}'".format(weight))
        # Load on the CPU: load_state_dict copies into the GPU parameters, and
        # this avoids materialising the tensors on whatever GPU they were
        # saved from (which need not be `gpu_ids`).
        checkpoint = torch.load(weight, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

        # Per-channel normalisation constants applied in `inference`.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, size=(480, 480)):
        """Run the network on one image and return its raw output tensor.

        Args:
            img: H x W x 3 array (channels-last, as the transpose below
                requires). Presumably a cv2 image — confirm the channel order
                matches the one used for training.
            size: ``(width, height)`` passed to ``cv2.resize``; defaults to
                480 x 480.

        Returns:
            The model output for a batch of one, on the current CUDA device.
        """
        # normalize
        img = cv2.resize(img, size)
        img = img.astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std
        # HWC -> CHW, then prepend the batch dimension.
        img = img.transpose((2, 0, 1))[np.newaxis, :, :, :]
        # to tensor
        tensor = torch.from_numpy(img).float().cuda()

        with torch.no_grad():
            output = self.model(tensor)
        return output
Пример #3
0
def load_model(model_path, num_classes=14, backbone='resnet', output_stride=16):
    """Build a DeepLab and copy in every compatible tensor from ``model_path``.

    A saved entry is used only if its key exists in the freshly built model
    and its shape matches; all other parameters keep their initial values.
    The model is moved to ``cuda:0`` when CUDA is available (CPU otherwise),
    wrapped in ``nn.DataParallel`` when more than one GPU is visible, and
    switched to eval mode.

    Returns:
        The DeepLab model, possibly wrapped in ``nn.DataParallel``.
    """
    print(f"Loading model from {model_path}")
    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride)

    # Always deserialise onto the CPU, whatever device the tensors came from.
    saved = torch.load(model_path, map_location=lambda storage, loc: storage)
    own_state = model.state_dict()

    compatible = {}
    for name, tensor in saved.items():
        # Skip unknown keys and tensors whose shape differs from the model's.
        if name in own_state and own_state[name].shape == tensor.shape:
            compatible[name] = tensor

    # Overlay the compatible tensors on the model's own state and load it.
    own_state.update(compatible)
    model.load_state_dict(own_state)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Load model in  ", gpu_count, " GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    model.eval()
    return model
Пример #4
0
    def __init__(self, args):
        """Build the DeepLab feature network and, optionally, an MRF layer.

        ``args.mpnet_mrf_mode`` picks the message-passing layer stored in
        ``self.mp_layer``: 'TRWP' -> MPModule_TRWP, 'ISGMR' or 'SGM' ->
        MPModule_ISGMR, 'MeanField' -> MeanField. For any other value
        ``self.enable_mp_layer`` is False and no ``mp_layer`` is created.
        """
        super(MPNet, self).__init__()
        mode = args.mpnet_mrf_mode
        self.enable_mp_layer = mode in ['TRWP', 'ISGMR', 'MeanField', 'SGM']
        self.args = args
        # The norm layer is selected by args.sync_bn; DeepLab's own
        # sync_bn flag below comes from the separate args.deeplab_sync_bn.
        norm_layer = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        self.enable_score_scale = args.enable_score_scale

        self.deeplab = DeepLab(
            num_classes=args.n_classes,
            backbone=args.deeplab_backbone,
            output_stride=args.deeplab_outstride,
            sync_bn=args.deeplab_sync_bn,
            freeze_bn=args.deeplab_freeze_bn,
            enable_interpolation=args.deeplab_enable_interpolation,
            pretrained_path=args.resnet_pretrained_path,
            norm_layer=norm_layer,
            enable_aspp=not args.disable_aspp)

        if not self.enable_mp_layer:
            return

        if mode == 'TRWP':
            self.mp_layer = MPModule_TRWP(args,
                                          enable_create_label_context=True,
                                          enable_saving_label=False)
        elif mode in {'ISGMR', 'SGM'}:
            self.mp_layer = MPModule_ISGMR(args,
                                           enable_create_label_context=True,
                                           enable_saving_label=False)
        elif mode == 'MeanField':
            self.mp_layer = MeanField(args, enable_create_label_context=True)
        else:
            # Unreachable: enable_mp_layer already restricts `mode`.
            assert False
Пример #5
0
    def __init__(self, weight, gpu_ids):
        """Build a 2-class MobileNet DeepLab on GPU ``gpu_ids`` and load ``weight``.

        Args:
            weight: path to a checkpoint file holding a ``'state_dict'`` entry.
            gpu_ids: CUDA device index passed to ``torch.cuda.set_device``.

        Raises:
            ValueError: if ``weight`` is None.
            RuntimeError: if ``weight`` is not an existing file.
        """
        self.model = DeepLab(num_classes=2,
                             backbone='mobilenet',
                             output_stride=16)

        torch.cuda.set_device(gpu_ids)
        self.model = self.model.cuda()

        # Explicit check rather than `assert`, which is stripped under -O.
        if weight is None:
            raise ValueError("=> no checkpoint path given (weight is None)")
        if not os.path.isfile(weight):
            raise RuntimeError("=> no checkpoint found at '{}'".format(weight))
        # Load on the CPU: load_state_dict copies into the GPU parameters, and
        # this avoids materialising the tensors on whatever GPU they were
        # saved from (which need not be `gpu_ids`).
        checkpoint = torch.load(weight, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

        # Per-channel normalisation constants for input preprocessing.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
Пример #6
0
def main():
    """Evaluate an AIMET quantisation-simulated MobileNet DeepLab (21 classes).

    BatchNorms are folded and ReLU6 layers replaced by ReLU *before* the
    checkpoint is loaded, so ``args.checkpoint_path`` presumably holds the
    state dict of such an already-folded model (TODO confirm). Encodings are
    computed on the CPU model; the quantsim model is then evaluated on the GPU
    and its mIoU printed.

    Raises:
        ValueError: if no checkpoint path is given or ``args.quant_scheme``
            is not a recognised name.
    """
    args = arguments()
    seed(args)

    model = DeepLab(backbone='mobilenet',
                    output_stride=16,
                    num_classes=21,
                    sync_bn=False)
    model.eval()

    from aimet_torch import batch_norm_fold
    from aimet_torch import utils
    args.input_shape = (1, 3, 513, 513)
    batch_norm_fold.fold_all_batch_norms(model, args.input_shape)
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6,
                                              torch.nn.ReLU)

    if not args.checkpoint_path:
        raise ValueError('checkpoint path {} must be specified'.format(
            args.checkpoint_path))
    model.load_state_dict(torch.load(args.checkpoint_path))

    data_loader_kwargs = {'worker_init_fn': work_init, 'num_workers': 0}
    _, val_loader, _, _ = make_data_loader(args, **data_loader_kwargs)
    eval_func_quant = model_eval(args, val_loader)
    eval_func = model_eval(args, val_loader)

    from aimet_common.defs import QuantScheme
    from aimet_torch.quantsim import QuantizationSimModel

    # Without an explicit quant_scheme, fall back to QuantizationSimModel's
    # own defaults (previously `kwargs` was left undefined in that case).
    kwargs = {}
    if hasattr(args, 'quant_scheme'):
        if args.quant_scheme == 'range_learning_tf':
            quant_scheme = QuantScheme.training_range_learning_with_tf_init
        elif args.quant_scheme == 'range_learning_tfe':
            quant_scheme = QuantScheme.training_range_learning_with_tf_enhanced_init
        elif args.quant_scheme == 'tf':
            quant_scheme = QuantScheme.post_training_tf
        elif args.quant_scheme == 'tf_enhanced':
            quant_scheme = QuantScheme.post_training_tf_enhanced
        else:
            raise ValueError("Got unrecognized quant_scheme: " +
                             args.quant_scheme)
        kwargs = {
            'quant_scheme': quant_scheme,
            'default_param_bw': args.default_param_bw,
            'default_output_bw': args.default_output_bw,
            'config_file': args.config_file
        }
    print(kwargs)
    sim = QuantizationSimModel(model.cpu(),
                               input_shapes=args.input_shape,
                               **kwargs)
    sim.compute_encodings(eval_func_quant, (1024, True))
    post_quant_top1 = eval_func(sim.model.cuda(), (99999999, True))
    print("Post Quant mIoU :", post_quant_top1)
def main(checkpoint_filename, input_image, output_image):
    """Segment one image with a 3-class ResNet DeepLab and save the mask.

    Args:
        checkpoint_filename: checkpoint whose ``'state_dict'`` keys may carry
            the ``module.`` prefix that ``nn.DataParallel`` adds when saving.
        input_image: path of the image to segment (converted to RGB).
        output_image: path the mask returned by ``predict`` is saved to.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Define network
    model = DeepLab(num_classes=3,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)

    checkpoint = torch.load(checkpoint_filename, map_location=device)
    # Strip the DataParallel "module." prefix so keys match a bare model.
    prefix = 'module.'
    cleaned = {}
    for name, tensor in checkpoint['state_dict'].items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = tensor
    model.load_state_dict(cleaned)
    # NOTE(review): the model itself is never moved to `device`; it runs
    # wherever it was built — confirm `predict` expects that.
    model.eval()

    rgb = Image.open(input_image).convert('RGB')
    predict(model, rgb).save(output_image)
Пример #8
0
def create_segmentation_models(encoder,
                               arch,
                               num_classes=4,
                               encoder_weights=None,
                               activation=None):
    '''
    Build a segmentation model selected by architecture name.

    segmentation_models_pytorch (https://github.com/qubvel/segmentation_models.pytorch)
    provides these architectures:
    - Unet
    - Linknet
    - FPN
    - PSPNet
    encoders: many; see the GitHub page above.

    DeepLabv3+ (https://github.com/jfzhang95/pytorch-deeplab-xception),
    selected with arch="deeplabv3plus", supports these encoders (backbones):
    - resnet (resnet101)
    - mobilenet
    - xception
    - drn

    Args:
        encoder: encoder name (smp) or backbone name (DeepLabv3+).
        arch: one of "Unet", "Linknet", "FPN", "PSPNet", "deeplabv3plus".
        num_classes: number of output classes.
        encoder_weights: smp encoder weights; not used by DeepLabv3+.
        activation: smp output activation; not used by DeepLabv3+.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``arch`` is unknown, or for "deeplabv3plus" when the
            repository path is not configured in the environment.
    '''
    if arch in ("Unet", "Linknet", "FPN", "PSPNet"):
        # All four smp classes share the same constructor signature.
        model_cls = getattr(smp, arch)
        return model_cls(encoder,
                         encoder_weights=encoder_weights,
                         classes=num_classes,
                         activation=activation)
    if arch == "deeplabv3plus":
        # NOTE(review): `deeplabv3plus_PATH` must be a module-level name
        # holding the environment-variable key — it is not defined here.
        if deeplabv3plus_PATH in os.environ:
            sys.path.append(os.environ[deeplabv3plus_PATH])
            from modeling.deeplab import DeepLab
            # DeepLab's first positional parameter is the backbone.
            return DeepLab(encoder, num_classes=num_classes)
        raise ValueError('Set deeplabv3plus path by environment variable.')
    raise ValueError(
        'arch {} is not found, set the correct arch'.format(arch))
Пример #9
0
def inference_A_sample_image(img_path, model_path, num_classes, backbone,
                             output_stride, sync_bn, freeze_bn):
    """Segment one image with DeepLab and write the class-index map to disk.

    The image is normalised with the mean/std below, run through a DeepLab
    loaded from ``model_path`` (which must hold a ``'state_dict'`` entry), and
    the per-pixel argmax class index is written to ``output.jpg`` as 8-bit.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    # read image (cv2.imread returns None instead of raising on failure)
    image = cv2.imread(img_path)
    if image is None:
        raise FileNotFoundError("could not read image '{}'".format(img_path))

    image = image.astype(np.float32)
    # Normalize pascal image (mean and std is from pascal.py).
    # NOTE(review): cv2 loads BGR; confirm the channel order used in training.
    image /= 255
    image -= (0.485, 0.456, 0.406)
    image /= (0.229, 0.224, 0.225)

    # numpy image: H x W x C  ->  torch image: C x H x W, then N=1 batch.
    image = np.ascontiguousarray(image.transpose((2, 0, 1)))[np.newaxis]

    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride,
                    sync_bn=sync_bn,
                    freeze_bn=freeze_bn,
                    pretrained=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # Put dropout and batch normalization layers into evaluation mode before
    # inference; otherwise results are inconsistent.
    model.eval()

    with torch.no_grad():
        output = model(torch.from_numpy(image).to(device))

    # argmax over the class axis; [0] drops the batch dimension -> (H, W).
    pred = np.argmax(output.cpu().numpy(), axis=1)[0]

    # save result; cv2.imwrite needs an 8-bit array for JPEG output
    # (class indices are assumed to be < 256).
    cv2.imwrite('output.jpg', pred.astype(np.uint8))
Пример #10
0
def main(args):
    """Run the 1-class brain-MRI DeepLab on the validation split and dump PNGs.

    For every sample, three 8-bit images are written into
    ``args.output_folder`` named ``{batch:06d}_{index:06d}_predict.png``
    (thresholded prediction), ``..._target.png`` (mask) and
    ``..._origin.png`` (de-normalised input).
    """
    vali_dataset = MRIBrainSegmentation(root_folder=args.root_folder,
                                        image_label=args.data_label,
                                        is_train=False)
    vali_loader = torch.utils.data.DataLoader(vali_dataset, batch_size=16, shuffle=False,
                                              num_workers=4, drop_last=False)

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Init and load model
    model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=8,
                    sync_bn=None,
                    freeze_bn=False)

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    with torch.no_grad():
        for i, sample in enumerate(vali_loader):
            print(i)
            data = sample['image'].to(device)
            target = sample['mask'].to(device)
            output = model(data)

            target = target.cpu().numpy()
            data = data.cpu().numpy()
            output = output.cpu().numpy()
            # NOTE(review): thresholds the raw model output at 0.5 — confirm
            # the network ends in a sigmoid.
            pred = (output > 0.5).astype(output.dtype)[:, 0]

            for j in range(len(target)):
                prefix = "{}/{:06d}_{:06d}".format(args.output_folder, i, j)
                cv2.imwrite(prefix + "_predict.png",
                            (pred[j] * 255).astype(np.uint8))
                cv2.imwrite(prefix + "_target.png",
                            (target[j] * 255).astype(np.uint8))
                # Undo the input normalisation: CHW -> HWC, * std + mean.
                img = data[j].transpose([1, 2, 0]) * std + mean
                img *= 255.0
                # Clip so out-of-range values saturate instead of wrapping.
                cv2.imwrite(prefix + "_origin.png",
                            np.clip(img, 0, 255).astype(np.uint8))
Пример #11
0
def test(args):
    """Overlay DeepLab predictions on validation images and save them.

    Each prediction is colour-decoded, resized into the sample's bounding
    box, blended over the original image and written to
    ``run/<dataset>/deeplab-resnet/output/<index>``.

    Raises:
        ValueError: if the loaded checkpoint is None.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, val_loader, _, nclass = make_data_loader(args, **kwargs)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(args.ckpt, map_location=device)
    if checkpoint is None:
        raise ValueError

    model = DeepLab(num_classes=nclass,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=True,
                    freeze_bn=False)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    output_dir = os.path.join('run', args.dataset, 'deeplab-resnet', 'output')
    with torch.no_grad():
        for i, sample in enumerate(tqdm(val_loader)):
            x1, x2, y1, y2 = [
                int(item) for item in sample['img_meta']['bbox_coord']
            ]  # bbox coord
            w, h = x2 - x1, y2 - y1
            img = sample['img_meta']['image'].squeeze().cpu().numpy()

            # Send inputs to the same device as the model (works on CPU too).
            inputs = sample['image'].to(device)
            output = model(inputs).squeeze().cpu().numpy()
            pred = np.argmax(output, axis=0)
            result = decode_segmap(pred, dataset=args.dataset, plot=False)

            # NOTE(review): the result must come back with shape (h, w, ...)
            # to fit the bbox slice below — confirm imresize's size order.
            result = imresize(result, (w, h))
            result_padding = np.zeros(img.shape, dtype=np.uint8)
            result_padding[y1:y2, x1:x2] = result
            # Blend in a wide integer type so the sum cannot wrap around in
            # uint8 before being saturated to the 8-bit range.
            blended = (img.astype(np.int32) // 2 +
                       result_padding.astype(np.int32) * 127)
            result = np.clip(blended, 0, 255).astype(np.uint8)
            plt.imsave(os.path.join(output_dir, str(i)), result)
Пример #12
0
    def __init__(self, config: BaseConfig):
        """Create model, losses, data loaders, TensorBoard writer and optimiser.

        The model is a 9-class DeepLab with output stride 8 placed on
        ``config.device``. SGD (Nesterov, momentum 0.9, weight decay 1e-4)
        with an exponential LR decay of 0.97 drives it. When
        ``config.parallel`` is set, the model is wrapped in
        DistributedDataParallel after the optimiser has captured its
        parameters.
        """
        cfg = config
        self._config = cfg
        net = DeepLab(num_classes=9, output_stride=8, sync_bn=False)
        self._model = net.to(cfg.device)
        self._border_loss = TotalLoss(cfg)
        self._direction_loss = CrossEntropyLoss()
        self._loaders = get_data_loaders(cfg)
        self._writer = SummaryWriter()

        sgd_options = dict(lr=cfg.lr,
                           weight_decay=1e-4,
                           nesterov=True,
                           momentum=0.9)
        self._optimizer = torch.optim.SGD(self._model.parameters(),
                                          **sgd_options)
        self._scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self._optimizer, gamma=0.97)

        if not cfg.parallel:
            return
        self._model = DistributedDataParallel(self._model,
                                              device_ids=[cfg.device])
Пример #13
0
def test(args):
    """Run the DeepLab model over the test split, optionally mixing up inputs.

    NOTE(review): the forward-pass output is not used or saved here — confirm
    whether this function is incomplete.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, _, test_loader, nclass = make_data_loader(args, **kwargs)
    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=False)
    # The model and its inputs stay on the CPU in this function, so load the
    # weights there too (no undefined `device`; works without CUDA).
    checkpoint = torch.load(args.pretrained, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    with torch.no_grad():
        for sample in tqdm(test_loader):  # loader choice: train / test / dev
            image, target = sample['image'], sample['label']
            if args.use_mixup:
                image, targets_a, targets_b, lam = mixup_data(
                    image, target, args.mixup_alpha, use_cuda=False)
            output = model(image)
Пример #14
0
def main():
    """Create the model and run the one-shot segmentation evaluation.

    Evaluates 1000 reference/query pairs from ``SSDatalayer(group)`` with
    group 1, writes a blended prediction image per query to
    ``save_bins/que_pred/``, and prints per-class IoU, the group's mean IoU,
    the binary (background/foreground) IoU and the ``SegScorer`` scores.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unknown model: {}'.format(args.model))

    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # Load on CPU; the model is moved to its device afterwards.
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    # For running different versions of PyTorch: keep only keys the model
    # knows, overlay them on the model's own state and load the merged dict,
    # so keys absent from the checkpoint keep their initial values.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes

    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)

    # cv2.imwrite fails silently when the directory does not exist.
    os.makedirs('save_bins/que_pred', exist_ok=True)

    with torch.no_grad():
        for count in tqdm(range(1000)):
            dat = datalayer.dequeue()
            deploy_info = dat['deploy_info']
            # Class indices start at 1 in the data layer; make them 0-based.
            class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1

            # Slot pairing follows the data layer: the shapes seen in practice
            # pair 'second_img' with 'first_label' (reference) and
            # 'first_img' with 'second_label' (query).
            ref_img = torch.Tensor(dat['second_img'][0]).to(device)    # (3, H_r, W_r)
            ref_label = torch.Tensor(dat['first_label'][0]).to(device)  # (1, H_r, W_r)
            query_img = torch.Tensor(dat['first_img'][0]).to(device)   # (3, H_q, W_q)
            query_label = torch.Tensor(
                dat['second_label'][0][0, :, :]).to(device)             # (H_q, W_q)

            ref_batch = ref_img.unsqueeze(0)      # [1, 3, H_r, W_r]
            ref_mask = ref_label.unsqueeze(1)     # [1, 1, H_r, W_r]
            query_batch = query_img.unsqueeze(0)  # [1, 3, H_q, W_q]

            # NOTE: torch.cat requires reference and query to share H and W.
            samples = torch.cat([ref_batch, query_batch], 0)
            pred = model(samples, ref_mask)
            h, w = query_label.size()
            pred = F.upsample(pred, size=(h, w), mode='bilinear')
            pred = F.softmax(pred, dim=1).squeeze()
            _, pred = torch.max(pred, dim=0)
            pred = pred.cpu().numpy().astype(np.int32)  # (H_q, W_q)

            # Query image as HxWx3, then the mask blended over it in colour.
            org_img = get_org_img(query_img.squeeze().cpu().numpy())
            img = mask_to_img(pred, org_img)
            cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % (count), img)

            query_label = query_label.cpu().numpy().astype(np.int32)
            scorer.update(pred, query_label, class_ind + 1)
            tp, tn, fp, fn = measure(query_label, pred)
            tp_list[class_ind] += tp
            fp_list[class_ind] += fp
            fn_list[class_ind] += fn

            # Relabel foreground pixels (in place) with the real class id so
            # the 21-way confusion histogram is indexed by class.
            pred[pred > 0.5] = class_ind + 1
            query_label[query_label > 0.5] = class_ind + 1
            hist += Metrics.fast_hist(pred, query_label, 21)

    # max(..., 1) guards classes where both prediction and label are empty.
    iou_list = [
        tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
        for ic in range(num_classes)
    ]

    print("-------------GROUP %d-------------" % (group))
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    # Collapse the 21x21 histogram into background vs. any foreground.
    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(), hist[1:, 0].sum(),
                            hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) -
                                     np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
Пример #15
0
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as tr

from modeling.deeplab import DeepLab
from dataloaders.utils import decode_segmap

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location keeps loading working on CPU-only machines when the checkpoint
# was saved from a GPU.
checkpoint = torch.load("./run/pascal/deeplab-resnet/model_best.pth",
                        map_location=device)

# 21-class ResNet DeepLab, matching the checkpoint of the pascal run above.
model = DeepLab(num_classes=21,
                backbone='resnet',
                output_stride=16,
                sync_bn=True,
                freeze_bn=False)

# NOTE(review): weights come from 'state_dict_G' rather than 'state_dict' —
# presumably a generator checkpoint; confirm the key.
model.load_state_dict(checkpoint['state_dict_G'])
model.eval()
model.to(device)


# Preprocessing pipeline, built once and reused by transform(): resize the
# shorter side to 513, centre-crop to 513x513, convert to a tensor and
# normalise each of the 3 channels with the mean/std below.
_PREPROCESS = tr.Compose([
    tr.Resize(513),
    tr.CenterCrop(513),
    tr.ToTensor(),
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])


def transform(image):
    """Return ``image`` (presumably a PIL image) as a normalised 3x513x513 tensor."""
    return _PREPROCESS(image)
Пример #16
0
    def __init__(self, data_train, data_valid, image_base_dir, instructions):
        """
        Prepare training: experiment folder, saver, TensorBoard writer, data
        loaders, DeepLab model, optimiser, loss, evaluator and (by default) a
        cosine learning-rate scheduler.

        :param data_train: training data, wrapped in a ``DictArrayDataSet``.
        :param data_valid: validation data, wrapped in a ``DictArrayDataSet``
            and also kept as ``self.data_valid``.
        :param image_base_dir: passed to both datasets as ``image_base_dir``.
        :param instructions: experiment settings keyed by ``STR.*`` constants.
            ``STR.MODEL_NAME``, ``STR.NN_INPUT_SIZE`` and ``STR.BATCH_SIZE``
            are always read; ``STR.EPOCHS`` is read while the LR scheduler is
            enabled (the default). Other keys fall back to defaults.
        """
        import warnings

        self.image_base_dir = image_base_dir
        self.data_valid = data_valid
        self.instructions = instructions

        # specify model save dir: <MODELS_FOLDER_PATH>/<model name>
        self.model_name = instructions[STR.MODEL_NAME]
        experiment_folder_path = os.path.join(paths.MODELS_FOLDER_PATH,
                                              self.model_name)

        if os.path.exists(experiment_folder_path):
            # Re-using a folder is allowed, but make it visible to the user.
            warnings.warn(
                "Experiment folder exists already. Files might be overwritten")
        os.makedirs(experiment_folder_path, exist_ok=True)

        # define saver and save instructions
        self.saver = Saver(folder_path=experiment_folder_path,
                           instructions=instructions)
        self.saver.save_instructions()

        # define Tensorboard Summary
        self.writer = SummaryWriter(log_dir=experiment_folder_path)

        nn_input_size = instructions[STR.NN_INPUT_SIZE]
        state_dict_file_path = instructions.get(STR.STATE_DICT_FILE_PATH, None)

        self.colour_mapping = mapping.get_colour_mapping()

        # Random cropping is enabled only when both the crop count and the
        # number of images per batch are configured.
        crops_per_image = instructions.get(STR.CROPS_PER_IMAGE, 10)
        apply_random_cropping = (STR.CROPS_PER_IMAGE in instructions.keys()) and \
                                (STR.IMAGES_PER_BATCH in instructions.keys())

        print("{}applying random cropping".format(
            "" if apply_random_cropping else "_NOT_ "))

        # define transformers for training
        t = [Normalize()]
        if apply_random_cropping:
            t.append(
                RandomCrop(min_size=instructions.get(STR.CROP_SIZE_MIN, 400),
                           max_size=instructions.get(STR.CROP_SIZE_MAX, 1000),
                           crop_count=crops_per_image))
        t += [
            Resize(nn_input_size),
            Flip(p_vertical=0.2, p_horizontal=0.5),
            ToTensor()
        ]
        transformations_train = transforms.Compose(t)

        # define transformers for validation (no cropping, no flipping)
        transformations_valid = transforms.Compose(
            [Normalize(), Resize(nn_input_size),
             ToTensor()])

        # set up data loaders
        dataset_train = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_train,
                                         num_classes=len(
                                             self.colour_mapping.keys()),
                                         transformation=transformations_train)

        # define batch sizes
        self.batch_size = instructions[STR.BATCH_SIZE]

        # With random cropping the training loader batches IMAGES_PER_BATCH
        # source images (each presumably expanded into several crops —
        # TODO confirm against RandomCrop/custom_collate); otherwise it
        # batches BATCH_SIZE samples.
        train_batch_size = (instructions[STR.IMAGES_PER_BATCH]
                            if apply_random_cropping else self.batch_size)
        self.data_loader_train = DataLoader(dataset=dataset_train,
                                            batch_size=train_batch_size,
                                            shuffle=True,
                                            collate_fn=custom_collate)

        dataset_valid = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_valid,
                                         num_classes=len(
                                             self.colour_mapping.keys()),
                                         transformation=transformations_valid)

        self.data_loader_valid = DataLoader(dataset=dataset_valid,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            collate_fn=custom_collate)

        self.num_classes = dataset_train.num_classes()

        # define model
        print("Building model")
        self.model = DeepLab(num_classes=self.num_classes,
                             backbone=instructions.get(STR.BACKBONE, "resnet"),
                             output_stride=instructions.get(
                                 STR.DEEPLAB_OUTPUT_STRIDE, 16))

        # load weights
        if state_dict_file_path is not None:
            print("loading state_dict from:")
            print(state_dict_file_path)
            load_state_dict(self.model, state_dict_file_path)

        # Parameter groups are collected before any DataParallel wrapping,
        # while the model's own helper methods are still directly reachable.
        # NOTE(review): both groups start at the same learning rate even
        # though the second comes from get_10x_lr_params(); confirm this is
        # intended (LR_Scheduler may rescale the groups itself).
        learning_rate = instructions.get(STR.LEARNING_RATE, 1e-5)
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': learning_rate
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': learning_rate
        }]

        # choose gpu or cpu
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if instructions.get(STR.MULTI_GPU, False):
            if torch.cuda.device_count() > 1:
                print("Using ", torch.cuda.device_count(), " GPUs!")
                self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=False)

        # calculate class weights from the class statistics file, if given
        if instructions.get(STR.CLASS_STATS_FILE_PATH, None):
            class_weights = calculate_class_weights(
                instructions[STR.CLASS_STATS_FILE_PATH],
                self.colour_mapping,
                modifier=instructions.get(STR.LOSS_WEIGHT_MODIFIER, 1.01))
            class_weights = torch.from_numpy(class_weights.astype(np.float32))
        else:
            class_weights = None
        self.criterion = SegmentationLosses(
            weight=class_weights, cuda=self.device.type != "cpu").build_loss()

        # Define Evaluator
        self.evaluator = Evaluator(self.num_classes)

        # Define lr scheduler
        self.scheduler = None
        if instructions.get(STR.USE_LR_SCHEDULER, True):
            self.scheduler = LR_Scheduler(mode="cos",
                                          base_lr=learning_rate,
                                          num_epochs=instructions[STR.EPOCHS],
                                          iters_per_epoch=len(
                                              self.data_loader_train))

        # print information before training start
        print("-" * 60)
        print("instructions")
        pprint(instructions)
        model_parameters = sum(p.nelement() for p in self.model.parameters())
        print("Model parameters: {:.2E}".format(model_parameters))

        self.best_prediction = 0.0
Пример #17
0
class Trainer:
    """Train and validate a DeepLab segmentation model configured by ``instructions``.

    ``instructions`` is a dict keyed by ``STR`` constants. The experiment folder
    ``paths.MODELS_FOLDER_PATH/<instructions[STR.MODEL_NAME]>`` receives the saved
    instructions, checkpoints (via ``Saver``) and TensorBoard logs.
    """

    def __init__(self, data_train, data_valid, image_base_dir, instructions):
        """Build data loaders, model, optimizer, loss, evaluator and LR scheduler.

        :param data_train: training data, passed as ``data`` to ``DictArrayDataSet``
        :param data_valid: validation data, passed as ``data`` to ``DictArrayDataSet``
        :param image_base_dir: forwarded as ``image_base_dir`` to both datasets
        :param instructions: training options keyed by ``STR`` constants; the keys
            ``MODEL_NAME``, ``NN_INPUT_SIZE`` and ``BATCH_SIZE`` are required, and
            ``EPOCHS`` is required while the LR scheduler is enabled (the default)
        """
        import warnings

        self.image_base_dir = image_base_dir
        self.data_valid = data_valid
        self.instructions = instructions

        # Specify the model save dir.
        self.model_name = instructions[STR.MODEL_NAME]
        experiment_folder_path = os.path.join(paths.MODELS_FOLDER_PATH,
                                              self.model_name)

        if os.path.exists(experiment_folder_path):
            # Merely constructing ``Warning(...)`` is a no-op; actually emit it.
            warnings.warn(
                "Experiment folder exists already. Files might be overwritten")
        os.makedirs(experiment_folder_path, exist_ok=True)

        # Define saver and save instructions.
        self.saver = Saver(folder_path=experiment_folder_path,
                           instructions=instructions)
        self.saver.save_instructions()

        # Define TensorBoard summary.
        self.writer = SummaryWriter(log_dir=experiment_folder_path)

        nn_input_size = instructions[STR.NN_INPUT_SIZE]
        state_dict_file_path = instructions.get(STR.STATE_DICT_FILE_PATH, None)

        self.colour_mapping = mapping.get_colour_mapping()
        num_mapping_classes = len(self.colour_mapping.keys())

        # Random cropping is enabled only when BOTH the crop count and the
        # number of images per batch are specified.
        crops_per_image = instructions.get(STR.CROPS_PER_IMAGE, 10)
        apply_random_cropping = (STR.CROPS_PER_IMAGE in instructions
                                 and STR.IMAGES_PER_BATCH in instructions)

        print("{}applying random cropping".format(
            "" if apply_random_cropping else "_NOT_ "))

        # Define transformations for training.
        t = [Normalize()]
        if apply_random_cropping:
            t.append(
                RandomCrop(min_size=instructions.get(STR.CROP_SIZE_MIN, 400),
                           max_size=instructions.get(STR.CROP_SIZE_MAX, 1000),
                           crop_count=crops_per_image))
        t += [
            Resize(nn_input_size),
            Flip(p_vertical=0.2, p_horizontal=0.5),
            ToTensor()
        ]
        transformations_train = transforms.Compose(t)

        # Define transformations for validation (no cropping, no flipping).
        transformations_valid = transforms.Compose(
            [Normalize(), Resize(nn_input_size), ToTensor()])

        # Set up data loaders.
        dataset_train = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_train,
                                         num_classes=num_mapping_classes,
                                         transformation=transformations_train)

        self.batch_size = instructions[STR.BATCH_SIZE]

        # With random cropping the loader batches IMAGES_PER_BATCH source
        # images (each presumably expanded into several crops by RandomCrop /
        # custom_collate — TODO confirm); otherwise it batches BATCH_SIZE samples.
        train_loader_batch_size = (instructions[STR.IMAGES_PER_BATCH]
                                   if apply_random_cropping else self.batch_size)
        self.data_loader_train = DataLoader(dataset=dataset_train,
                                            batch_size=train_loader_batch_size,
                                            shuffle=True,
                                            collate_fn=custom_collate)

        dataset_valid = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_valid,
                                         num_classes=num_mapping_classes,
                                         transformation=transformations_valid)

        self.data_loader_valid = DataLoader(dataset=dataset_valid,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            collate_fn=custom_collate)

        self.num_classes = dataset_train.num_classes()

        # Define model.
        print("Building model")
        self.model = DeepLab(num_classes=self.num_classes,
                             backbone=instructions.get(STR.BACKBONE, "resnet"),
                             output_stride=instructions.get(
                                 STR.DEEPLAB_OUTPUT_STRIDE, 16))

        # Load weights.
        if state_dict_file_path is not None:
            print("loading state_dict from:")
            print(state_dict_file_path)
            load_state_dict(self.model, state_dict_file_path)

        learning_rate = instructions.get(STR.LEARNING_RATE, 1e-5)
        # NOTE(review): the "10x" group starts at the base LR here (other
        # trainers in this file use lr * 10); confirm whether LR_Scheduler
        # rescales it when enabled.
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': learning_rate
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': learning_rate
        }]

        # Choose GPU or CPU.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if instructions.get(STR.MULTI_GPU, False):
            if torch.cuda.device_count() > 1:
                print("Using ", torch.cuda.device_count(), " GPUs!")
                self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        # Define optimizer.
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=False)

        # Calculate class weights from the class statistics file, if given.
        if instructions.get(STR.CLASS_STATS_FILE_PATH, None):
            class_weights = calculate_class_weights(
                instructions[STR.CLASS_STATS_FILE_PATH],
                self.colour_mapping,
                modifier=instructions.get(STR.LOSS_WEIGHT_MODIFIER, 1.01))
            class_weights = torch.from_numpy(class_weights.astype(np.float32))
        else:
            class_weights = None
        self.criterion = SegmentationLosses(
            weight=class_weights, cuda=self.device.type != "cpu").build_loss()

        # Define evaluator.
        self.evaluator = Evaluator(self.num_classes)

        # Define LR scheduler (cosine), enabled unless explicitly disabled.
        self.scheduler = None
        if instructions.get(STR.USE_LR_SCHEDULER, True):
            self.scheduler = LR_Scheduler(mode="cos",
                                          base_lr=learning_rate,
                                          num_epochs=instructions[STR.EPOCHS],
                                          iters_per_epoch=len(
                                              self.data_loader_train))

        # Print information before training starts.
        print("-" * 60)
        print("instructions")
        pprint(instructions)
        model_parameters = sum(p.nelement() for p in self.model.parameters())
        print("Model parameters: {:.2E}".format(model_parameters))

        self.best_prediction = 0.0

    def train(self, epoch):
        """Run one training epoch, logging per-iteration and per-epoch loss.

        :param epoch: epoch index; combined with the batch index to form the
            TensorBoard global step and passed to the LR scheduler
        """
        self.model.train()
        train_loss = 0.0

        pbar = tqdm(self.data_loader_train)
        num_batches_train = len(self.data_loader_train)

        for i, sample in enumerate(pbar):
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_prediction)

            output = self.model(nn_input)
            loss = self.criterion(output, nn_target)

            loss_value = loss.item()
            train_loss += loss_value
            pbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value,
                                   i + num_batches_train * epoch)

            # Calculate gradient and update model weights.
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Note: the epoch scalar is the SUM of batch losses, not the mean.
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[Epoch: {}, num images/crops: {}]".format(
            epoch, num_batches_train * self.batch_size))

        print("Loss: {:.2f}".format(train_loss))

    def validation(self, epoch):
        """Evaluate on the validation set, log metrics and save a checkpoint.

        The checkpoint is flagged as best when its mIoU exceeds the best seen so far.

        :param epoch: epoch index used as the TensorBoard step
        """
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0

        pbar = tqdm(self.data_loader_valid, desc='\r')
        num_batches_val = len(self.data_loader_valid)

        for i, sample in enumerate(pbar):
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device,
                                                 dtype=torch.float)

            with torch.no_grad():
                output = self.model(nn_input)

            loss = self.criterion(output, nn_target)
            test_loss += loss.item()
            pbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            # Class prediction = argmax over the channel axis.
            pred = np.argmax(output.cpu().numpy(), axis=1)
            self.evaluator.add_batch(nn_target.cpu().numpy(), pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print("[Epoch: {}, num crops: {}]".format(
            epoch, num_batches_val * self.batch_size))
        print(
            "Acc:{:.2f}, Acc_class:{:.2f}, mIoU:{:.2f}, fwIoU: {:.2f}".format(
                Acc, Acc_class, mIoU, FWIoU))
        print("Loss: {:.2f}".format(test_loss))

        is_best = mIoU > self.best_prediction
        if is_best:
            self.best_prediction = mIoU
        self.saver.save_checkpoint(self.model, is_best, epoch)
Пример #18
0
def _next_batch(loader, iterator):
    """Return ``(batch, iterator)``, restarting from ``loader`` when ``iterator`` is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def _set_requires_grad(net, flag):
    """Enable or disable gradients for every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad = flag


def _bce_to_label(bce_loss, logits, label):
    """BCE between discriminator ``logits`` and a constant ``label`` map on the same device."""
    return bce_loss(logits, torch.full_like(logits, float(label)))


def _first_sample_accuracy(logits, labels):
    """``accuracy`` of the channel-argmax prediction vs. the label for batch item 0."""
    pred = np.argmax(logits.detach().cpu().numpy()[0], axis=0)
    gt = labels.cpu().numpy()[0]
    return accuracy(pred, gt)


def _build_segmentation_model(args, n_class):
    """Build the segmentation network for ``args.backbone``.

    Raises:
        ValueError: if the backbone is not 'resnet' or 'resnet_multiscale'.
    """
    if args.backbone == 'resnet':
        model_cls = DeepLab
    elif args.backbone == 'resnet_multiscale':
        model_cls = DeepLabCA
    else:
        # Previously this only printed, leaving ``model`` unbound (NameError later).
        raise ValueError('backbone not exists! ({})'.format(args.backbone))
    return model_cls(num_classes=n_class,
                     backbone=args.backbone,
                     output_stride=args.out_stride,
                     sync_bn=args.sync_bn,
                     freeze_bn=args.freeze_bn)


def main():
    """Adversarially adapt a DeepLab model across domain and scale.

    Source data (DeepGlobe tiles) trains the segmentation network with
    ``cross_entropy2d``. Target data (ISPRS Vaihingen at two resolutions, via
    ``ISPRS_dataset_multi``) is used only through two discriminators:
    ``netD_domain`` separates source predictions from target scale-1
    predictions, and ``netD_scale`` separates target scale-1 from target
    scale-2 predictions.

    Each run writes to ``results/<model>_<dataset>_<timestamp>``. Progress is
    logged every 100 iterations to stdout and a log file, and a checkpoint is
    saved every 1000 iterations.

    Raises:
        ValueError: for an unsupported backbone, or if the source loss is NaN.
    """
    here = osp.dirname(osp.abspath(__file__))

    trainOpts = TrainOptions()
    args = trainOpts.get_arguments()

    now = datetime.datetime.now()
    args.out = osp.join(
        here, 'results',
        args.model + '_' + args.dataset + '_' + now.strftime('%Y%m%d__%H%M%S'))
    os.makedirs(args.out, exist_ok=True)

    log_file = osp.join(args.out, args.model + '_' + args.dataset + '.log')

    checkpoint_dir = osp.join(args.out, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Must be set before CUDA is initialised for it to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    # Source domain: DeepGlobe; the first two thirds of the sorted tiles train.
    MAIN_FOLDER = args.folder + 'DeepGlobe/land-train_crop/'
    DATA_FOLDER = MAIN_FOLDER + '{}_sat.jpg'
    LABEL_FOLDER = MAIN_FOLDER + '{}_mask.png'
    all_files = sorted(glob(DATA_FOLDER.replace('{}', '*')))

    all_ids = [f.split('/')[-1].split('_')[0] for f in all_files]
    train_ids = all_ids[:int(len(all_ids) / 3 * 2)]

    train_set = DeepGlobe_dataset(train_ids, DATA_FOLDER, LABEL_FOLDER)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size)

    # Target domain: Vaihingen at original and rescaled resolution.
    MAIN_FOLDER = args.folder + 'ISPRS_dataset/Vaihingen/'
    DATA_FOLDER1 = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif'
    LABEL_FOLDER1 = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif'

    DATA_FOLDER2 = MAIN_FOLDER + 'resized_resolution5/top_mosaic_09cm_area{}.tif'
    LABEL_FOLDER2 = MAIN_FOLDER + 'gts_for_participants5/top_mosaic_09cm_area{}.tif'

    train_ids = ['1', '3', '5', '21', '23', '26', '7', '13', '17', '32', '37']

    target_set = ISPRS_dataset_multi(5,
                                     train_ids,
                                     DATA_FOLDER1,
                                     LABEL_FOLDER1,
                                     DATA_FOLDER2,
                                     LABEL_FOLDER2,
                                     cache=args.cache)
    target_loader = torch.utils.data.DataLoader(target_set,
                                                batch_size=args.batch_size)

    LABELS = [
        "roads", "buildings", "low veg.", "trees", "cars", "clutter", "unknown"
    ]
    N_CLASS = len(LABELS)

    # 2. model
    model = _build_segmentation_model(args, N_CLASS)

    # Decoder/head parameters train at 10x the backbone learning rate.
    train_params = [{
        'params': model.get_1x_lr_params(),
        'lr': args.lr
    }, {
        'params': model.get_10x_lr_params(),
        'lr': args.lr * 10
    }]

    # 3. optimizers
    netD_domain = FCDiscriminator(num_classes=N_CLASS)
    netD_scale = FCDiscriminator(num_classes=N_CLASS)

    optim_netG = torch.optim.SGD(train_params,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=args.nesterov)
    optim_netD_domain = optim.Adam(netD_domain.parameters(),
                                   lr=args.lr_D,
                                   betas=(0.9, 0.99))
    optim_netD_scale = optim.Adam(netD_scale.parameters(),
                                  lr=args.lr_D,
                                  betas=(0.9, 0.99))

    if cuda:
        model = model.cuda()
        netD_domain = netD_domain.cuda()
        netD_scale = netD_scale.cuda()

    if args.resume:
        # Load on CPU; load_state_dict copies into the model's own device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint)

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # 4. training
    max_iter = 50000

    # Iterators restart when a loader is exhausted (max_iter exceeds one epoch).
    train_iter = iter(train_loader)
    target_iter = iter(target_loader)

    source_label = 0
    target_label = 1

    source_scale_label = 0
    target_scale_label = 1

    train_loss, train_acc = [], []
    target_acc_s1, target_acc_s2 = [], []

    with open(log_file, 'w') as mylog:
        for iter_ in range(max_iter):
            optim_netG.zero_grad()
            adjust_learning_rate(optim_netG, iter_, args)

            optim_netD_domain.zero_grad()
            optim_netD_scale.zero_grad()
            adjust_learning_rate_D(optim_netD_domain, iter_, args)
            adjust_learning_rate_D(optim_netD_scale, iter_, args)

            # Running statistics cover (at most) the last 1000 iterations.
            if iter_ % 1000 == 0:
                train_loss, train_acc = [], []
                target_acc_s1, target_acc_s2 = [], []

            # Freeze discriminators while updating the segmentation network.
            _set_requires_grad(netD_domain, False)
            _set_requires_grad(netD_scale, False)

            (im_s, label_s), train_iter = _next_batch(train_loader, train_iter)
            (im_t_s1, label_t_s1, im_t_s2,
             label_t_s2), target_iter = _next_batch(target_loader, target_iter)

            if cuda:
                im_s, label_s = im_s.cuda(), label_s.cuda()
                im_t_s1, label_t_s1 = im_t_s1.cuda(), label_t_s1.cuda()
                im_t_s2, label_t_s2 = im_t_s2.cuda(), label_t_s2.cuda()

            # ---- Train segmentation network (G) ----
            # Supervised loss on source data.
            pred_seg = model(im_s)
            seg_loss = cross_entropy2d(pred_seg, label_s)
            seg_loss /= len(im_s)
            loss_data = seg_loss.item()
            if np.isnan(loss_data):
                raise ValueError('loss is nan while training')
            seg_loss.backward()

            train_acc.append(_first_sample_accuracy(pred_seg, label_s))
            train_loss.append(loss_data)

            # Target predictions at both scales (accuracy is monitoring only).
            pred_s1 = model(im_t_s1)
            target_acc_s1.append(_first_sample_accuracy(pred_s1, label_t_s1))

            pred_s2 = model(im_t_s2)
            target_acc_s2.append(_first_sample_accuracy(pred_s2, label_t_s2))

            # Adversarial loss: make target outputs look like the source
            # domain / source scale to the discriminators.
            pred_d = netD_domain(F.softmax(pred_s1, dim=1))
            pred_s = netD_scale(F.softmax(pred_s2, dim=1))

            loss_adv_domain = _bce_to_label(bce_loss, pred_d, source_label)
            loss_adv_scale = _bce_to_label(bce_loss, pred_s, source_scale_label)

            loss = (args.lambda_adv_domain * loss_adv_domain +
                    args.lambda_adv_scale * loss_adv_scale)
            loss /= len(im_t_s1)
            loss.backward()

            # ---- Train discriminators (D) ----
            _set_requires_grad(netD_domain, True)
            _set_requires_grad(netD_scale, True)

            # Source domain (source images) and source scale (target scale 1).
            pred_seg, pred_s1 = pred_seg.detach(), pred_s1.detach()
            pred_d = netD_domain(F.softmax(pred_seg, dim=1))
            pred_s = netD_scale(F.softmax(pred_s1, dim=1))

            loss_D_domain = _bce_to_label(bce_loss, pred_d,
                                          source_label) / len(im_s) / 2
            loss_D_scale = _bce_to_label(bce_loss, pred_s,
                                         source_scale_label) / len(im_s) / 2

            loss_D_domain.backward()
            loss_D_scale.backward()

            # Target domain (target scale 1) and target scale (target scale 2).
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            pred_d = netD_domain(F.softmax(pred_s1, dim=1))
            pred_s = netD_scale(F.softmax(pred_s2, dim=1))

            loss_D_domain = _bce_to_label(bce_loss, pred_d,
                                          target_label) / len(im_s) / 2
            loss_D_scale = _bce_to_label(bce_loss, pred_s,
                                         target_scale_label) / len(im_s) / 2

            loss_D_domain.backward()
            loss_D_scale.backward()

            optim_netG.step()
            optim_netD_domain.step()
            optim_netD_scale.step()

            if iter_ % 100 == 0:
                stats = (iter_, max_iter,
                         sum(train_loss) / len(train_loss),
                         sum(train_acc) / len(train_acc),
                         sum(target_acc_s1) / len(target_acc_s1),
                         sum(target_acc_s2) / len(target_acc_s2))
                print(
                    'Train [{}/{} Source loss:{:.6f} acc:{:.4f} % Target s1 acc:{:4f}% Target s2 acc:{:4f}%]'
                    .format(*stats))
                print(
                    'Train  [{}/{} Source loss:{:.6f} acc:{:.4f} % Target s1 acc:{:4f}% Target s2 acc:{:4f}%]'
                    .format(*stats),
                    file=mylog,
                    flush=True)

            if iter_ % 1000 == 0:
                print('saving checkpoint.....')
                torch.save(model.state_dict(),
                           osp.join(checkpoint_dir, 'iter{}.pth'.format(iter_)))
Пример #19
0
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument(
        '--pretrained',
        type=str,
        default='/Users/yulian/Downloads/mixup_model_best.pth.tar',
        help='pretrained model')
    parser.add_argument('--color',
                        type=str,
                        default='purple',
                        choices=['purple', 'green', 'blue', 'red'],
                        help='Color your hair (default: purple)')
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = DeepLab(backbone=args.backbone,
                  output_stride=16,
                  num_classes=2,
                  sync_bn=False).to(device)
    net.load_state_dict(
        torch.load(args.pretrained, map_location=device)['state_dict'])
    net.eval()
    cam = cv2.VideoCapture(0)
    if not cam.isOpened():
        raise Exception("webcam is not detected")

    while (True):
        # ret : frame capture(boolean)
        # frame : Capture frame
        ret, image = cam.read()

        if (ret):
            image, mask = get_image_mask(image, net)
Пример #20
0
# Relative checkpoint/output paths, appended to ``now_dir`` (and ``target_dir``).
cp_seg_name = '/_checkpoint/cp_seg.pth'
cp_act_name = '/_checkpoint/cp_act.pth'
save_seg_name = '/model.pth'
save_act_name = '/action.pth'

# segmentation
import traceback

import torch
from modeling.deeplab import DeepLab

n_class = 10

# NOTE(review): ``now_dir`` and ``target_dir`` are not defined in this snippet;
# they must exist before this point, or the export reports a failure.
try:
    # Rebuild the architecture, load the trained weights, then save the whole
    # model object (not just its state_dict).
    model = DeepLab(num_classes=n_class,
                    backbone='xception',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)
    model = model.cuda()

    checkpoint = torch.load(now_dir + cp_seg_name)
    model.load_state_dict(checkpoint['state_dict'])
    torch.save(model, now_dir + target_dir + save_seg_name)
    print('segmentation model - OK!')
except Exception:
    # Best effort: keep going to the action model, but do not hide why the
    # segmentation export failed (a bare except also swallowed Ctrl-C).
    print('segmentation model - Failed!')
    traceback.print_exc()

# action
import torchvision.models as models
import torch.nn as nn
Пример #21
0
def main(args):
    """Train a segmentation model on ``MRIBrainSegmentation`` data.

    ``args`` is handed to ``ConfigParser``, whose ``config`` dict supplies the
    dataset paths, loader sizes, network choice (``net_name``: "deeplab", or
    anything else for a 3->1 channel ``Unet``) and optimizer settings.
    Training uses ``loss.dice_loss``, SGD (momentum 0.9) and ``Poly_Scheduler``.
    It runs on ``cuda:0`` when CUDA is available and on the CPU otherwise.
    """
    config = ConfigParser(args)
    cfg = config.config
    logger = get_logger(config.log_dir, "train")

    def make_dataset(image_label, is_train):
        # Train and validation datasets differ only in label file and is_train.
        return MRIBrainSegmentation(root_folder=cfg['root_folder'],
                                    image_label=image_label,
                                    is_train=is_train,
                                    ignore_label=0,
                                    input_size=cfg['input_size'])

    train_dataset = make_dataset(cfg['train_data'], is_train=True)
    vali_dataset = make_dataset(cfg['validation_data'], is_train=False)

    # drop_last only for training so every optimisation step sees a full batch.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg["train_batch_size"],
        shuffle=True,
        num_workers=cfg["workers"],
        drop_last=True)

    vali_loader = torch.utils.data.DataLoader(
        vali_dataset,
        batch_size=cfg["vali_batch_size"],
        shuffle=False,
        num_workers=cfg["workers"],
        drop_last=False)

    if cfg['net_name'] == "deeplab":
        model = DeepLab(num_classes=1,
                        backbone=cfg['backbone'],
                        output_stride=cfg['output_stride'],
                        sync_bn=cfg['sync_bn'],
                        freeze_bn=cfg['freeze_bn'])
    else:
        model = Unet(in_channels=3, out_channels=1, init_features=32)

    criterion = loss.dice_loss
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg["lr"],
                          momentum=0.9,
                          weight_decay=cfg["weight_decay"])
    metrics_name = []
    scheduler = Poly_Scheduler(base_lr=cfg['lr'],
                               num_epochs=config['epoch'],
                               iters_each_epoch=len(train_loader))

    # Fall back to CPU instead of failing on hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    trainer = Trainer(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      train_loader=train_loader,
                      nb_epochs=config['epoch'],
                      valid_loader=vali_loader,
                      lr_scheduler=scheduler,
                      logger=logger,
                      log_dir=config.save_dir,
                      metrics_name=metrics_name,
                      resume=config['resume'],
                      save_dir=config.save_dir,
                      device=device,
                      monitor="max iou_class_1",
                      early_stop=-1)
    trainer.train()
Пример #22
0
    def __init__(self, args):
        """Set up logging, data, a DeepLab model plus discriminator, losses and LR schedule.

        :param args: parsed options. Fields used here include backbone,
            out_stride, sync_bn, freeze_bn, lr, momentum, weight_decay,
            nesterov, use_balanced_weights, dataset, loss_type, lr_scheduler,
            epochs, cuda, gpu_ids, workers, resume and ft.
        :raises RuntimeError: if ``args.resume`` is set but the file does not exist.
        """
        self.args = args

        # Define saver.
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define TensorBoard summary.
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define dataloaders.
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network.
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # Discriminator input channels must match the segmentation output
        # (one per class); this was hard-coded to 19 (the Cityscapes class count).
        model_D = FCDiscriminator(num_classes=self.nclass)

        # Head/decoder parameters train at 10x the backbone learning rate.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define optimizers.
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                                       lr=1e-4,
                                       betas=(0.9, 0.99))

        # Define criterion, optionally with class-balanced weights (cached on disk).
        if args.use_balanced_weights:
            # os.path.join instead of hard-coded Windows backslashes.
            classes_weights_path = os.path.join(
                'dataloders', 'datasets',
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        self.model, self.optimizer = model, optimizer
        self.model_D, self.optimizer_D = model_D, optimizer_D

        # Define evaluator.
        self.evaluator = Evaluator(self.nclass)
        # Define LR scheduler.
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using CUDA: wrap both networks in DataParallel over args.gpu_ids.
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model_D = torch.nn.DataParallel(self.model_D,
                                                 device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            patch_replication_callback(self.model_D)
            self.model = self.model.cuda()
            self.model_D = self.model_D.cuda()

        # Resume from a checkpoint (segmentation model only; model_D and
        # optimizer_D are not restored here).
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Under DataParallel the real model is ``.module``.
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning.
        if args.ft:
            args.start_epoch = 0
class DeeplabRos:
    """ROS node that segments camera frames with a 4-class DeepLab model.

    Frames arriving on ``/cam2/pylon_camera_node/image_raw`` are run through a
    MobileNet-backbone DeepLab and the colourised label map is published on
    ``segmentation_image`` with encoding ``bgr8``.
    """

    def __init__(self):
        # GPU assignment.
        # NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not
        # been initialised yet in this process; confirm it is set early enough.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint. map_location remaps the stored tensors onto the
        # selected device, so a checkpoint saved on a GPU also loads on a
        # CPU-only machine.
        self.checkpoint = torch.load(
            "./src/deeplab_ros/data/model_best.pth.tar",
            map_location=self.device)

        # Load model.
        self.model = DeepLab(num_classes=4,
                             backbone='mobilenet',
                             output_stride=16,
                             sync_bn=True,
                             freeze_bn=False)
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model = self.model.to(self.device)
        # The node only runs inference, so switch to eval mode once here.
        self.model.eval()

        # Pre-processing is identical for every frame: build it once.
        # NOTE(review): frames are decoded as BGR while these look like the
        # usual RGB-ordered ImageNet mean/std; confirm the model was trained
        # on BGR input.
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # ROS init. The publisher and bridge are created before the
        # subscriber because the callback (which uses both) can be invoked
        # as soon as the subscriber is registered.
        self.bridge = CvBridge()
        self.image_pub = rospy.Publisher("segmentation_image",
                                         ImageMsg,
                                         queue_size=1)
        self.image_sub = rospy.Subscriber("/cam2/pylon_camera_node/image_raw",
                                          ImageMsg,
                                          self.callback,
                                          queue_size=1,
                                          buff_size=2**24)

    def callback(self, data):
        """Segment one camera frame and publish the colourised label map.

        Args:
            data: incoming image message; converted with cv_bridge as "bgr8".
        """
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        start_time = time.time()

        # torch.no_grad() disables autograd only for this inference instead
        # of switching it off globally for the calling thread.
        with torch.no_grad():
            inputs = self.tfms(cv_image).to(self.device)
            # Add a batch dimension for the model, then remove only that
            # dimension again (presumably leaving (num_classes, H, W) scores).
            output = self.model(inputs.unsqueeze(0)).squeeze(0).cpu().numpy()
        pred = np.argmax(output, axis=0)
        pred_img = self.label_to_color_image(pred)

        msg = self.bridge.cv2_to_imgmsg(pred_img, "bgr8")

        inference_time = time.time() - start_time
        print("inference time: ", inference_time)

        self.image_pub.publish(msg)

    def label_to_color_image(self, pred, class_num=4):
        """Convert a label map into a 3-channel uint8 colour image.

        Palette (channel order B, G, R, matching the "bgr8" encoding used when
        publishing): 0 Unlabeled -> (0, 0, 0), 1 Building -> (0, 0, 128),
        2 Lane-marking -> (0, 128, 0), 3 Fence -> (128, 0, 0).

        Args:
            pred: integer label map, normally of shape (H, W).
            class_num: number of palette entries to apply (at most 4). Pixels
                whose label is not in ``range(class_num)`` stay black.

        Returns:
            np.ndarray of shape ``pred.shape + (3,)`` with dtype uint8.
        """
        label_colors = np.array([(0, 0, 0), (0, 0, 128), (0, 128, 0),
                                 (128, 0, 0)], dtype=np.uint8)  # bgr
        pred = np.asarray(pred)
        color = np.zeros(pred.shape + (3,), dtype=np.uint8)
        for i in range(class_num):
            color[pred == i] = label_colors[i]
        return color
# Example #24 (0)
def main():
    """Create the model and start the training.

    Adversarial training of a DeepLab (ResNet backbone) segmentation network
    against an ``FCDiscriminator``. The discriminator learns to label softmax
    predictions 0 and one-hot ground-truth masks 1. The segmentation network
    is trained with cross-entropy plus an adversarial term whose target
    label is 1.

    All settings come from the module-level ``args``. Snapshots of both
    networks are saved to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'DeepLab'``.
        ValueError: if ``args.gan`` is neither ``'Vanilla'`` nor ``'LS'``.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    cudnn.enabled = True

    # Create network. No checkpoint (e.g. args.restore_from) is loaded here.
    if args.model == 'DeepLab':
        model = DeepLab(backbone='resnet', output_stride=16)
    else:
        raise NotImplementedError

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Init D. No pretrained discriminator weights are restored.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D1.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    train_loader = data_loader(args)

    # model.optim_parameters(args) supplies the per-layer lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # Adversarial loss: BCE on logits ('Vanilla') or least squares ('LS').
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError("unsupported args.gan: {!r}".format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss()

    # Set up tensorboard.
    writer = None
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration counter
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)

        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, count)

        # NOTE(review): every accumulation step reuses the same batch `dat`;
        # with iter_size > 1 one batch is repeated rather than several being
        # accumulated. Confirm this is intended.
        for _ in range(args.iter_size):

            # ---- train G ----
            # Freeze D so the generator's adversarial loss does not
            # accumulate gradients in the discriminator.
            for param in model_D1.parameters():
                param.requires_grad = False

            # Episode sampling: e.g. for group 0 the training split has 15
            # classes [0, 1, 2, ...] and the validation split 5. Two training
            # classes are picked at random. Two images of one class form the
            # anchor and positive images (same class), and one image of the
            # other class is the negative (different class). The anchor is
            # the query image.
            # dat = anchor image/mask, positive image/mask (same class as the
            # anchor), negative image/mask (different class).
            anchor_img, anchor_mask, pos_img, pos_mask, neg_img, neg_mask = dat

            anchor_img = anchor_img.to(device)
            anchor_mask = anchor_mask.to(device)
            pos_img = pos_img.to(device)
            pos_mask = pos_mask.to(device)

            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)  # add channel dim
            pos_mask = torch.unsqueeze(pos_mask, dim=1)
            samples = torch.cat([pos_img, anchor_img], 0)

            pred = model(samples, pos_mask)
            # Upsample the logits to the anchor-mask resolution so the loss
            # compares maps of the same size.
            pred = F.interpolate(pred,
                                 size=anchor_mask.shape[-2:],
                                 mode='bilinear',
                                 align_corners=True)

            target = anchor_mask.squeeze().unsqueeze(0).long()
            loss_seg1_anchor = seg_loss(pred, target)
            D_out1 = model_D1(F.softmax(pred, dim=1))
            # Use the "ground truth" label (1) for the prediction, so the
            # generator is trained to make D score its output as real.
            loss_adv_target1 = bce_loss(D_out1, torch.ones_like(D_out1))

            loss = loss_seg1_anchor + args.lambda_adv_target1 * loss_adv_target1

            # Proper normalization over the accumulation steps.
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # ---- train D ----
            # Bring back requires_grad.
            for param in model_D1.parameters():
                param.requires_grad = True

            # Train with the anchor prediction (detached so no gradient
            # reaches G), labelled 0.
            pred_target1 = pred.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(D_out1, torch.zeros_like(D_out1))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # Train with the one-hot ground truth, labelled 1.
            anchor_gt = one_hot(anchor_mask).to(device)
            D_out1 = model_D1(anchor_gt)
            loss_D1 = bce_loss(D_out1, torch.ones_like(D_out1))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        count = count + 1
        if writer is not None:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
            }

            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
            .format(count, args.num_steps, loss_seg_value1_anchor,
                    loss_adv_target_value1, loss_D_value1))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '_D1.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'voc2012_' + str(count) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(count) + '_D1.pth'))

    if writer is not None:
        writer.close()
from modeling.deeplab import DeepLab
import kornia
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import numpy as np

# This example only uses one image, so the CPU is sufficient.
device = torch.device("cpu")

# Load pre-trained weights and put the network into inference mode.
# map_location="cpu" lets weights saved from a GPU load on a CPU-only host.
network = DeepLab(num_classes=18)
network.load_state_dict(
    torch.load("segmentation-model/epoch-14", map_location="cpu"))
network.eval()
network.to(device)

# Load the example image and resize it: DeepLab uses a lot of dilated
# convolutions and doesn't work very well for low-resolution images.
# convert("RGB") guarantees a 3-channel tensor even for greyscale, palette
# or RGBA files, and the context manager closes the underlying file.
with Image.open("nate.jpg") as src:
    image = src.convert("RGB")
# PIL sizes are (width, height).
scaled_image = image.resize((418, 512), resample=Image.LANCZOS)
image_tensor = TF.to_tensor(scaled_image)

# send the input through the network. unsqueeze is used to
# add a batch dimension, because torch always expects a batch
# but in this case it's just one image
# I then use Kornia to resize the mask back to 218x178, then
# squeeze to remove the batch dimension again (Kornia also
# always expects a batch dimension).
with torch.no_grad():
# Example #26 (0)
def main():
    """Create the model and start the training.

    Trains a DeepLab (ResNet backbone, output stride 8) with a cross-entropy
    segmentation loss plus an ``AffinityFieldLoss`` on the anchor (query)
    image. The model is called with the positive and anchor images stacked
    into one batch, together with the positive mask.

    The discriminator, relation-matrix (R) and attention-matrix (A) losses
    are disabled in this variant. Their running values stay 0 but are still
    printed and logged so the log format is unchanged.

    All settings come from the module-level ``args``. Snapshots are saved to
    ``args.snapshot_dir`` every ``args.save_pred_every`` iterations and once
    more when ``args.num_steps_stop`` is reached.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'DeepLab'``.
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    cudnn.enabled = True

    # Create network. No checkpoint (e.g. args.restore_from) is loaded here.
    if args.model == 'DeepLab':
        model = DeepLab(backbone='resnet', output_stride=8)
    else:
        raise NotImplementedError

    model.train()
    model.to(device)

    cudnn.benchmark = True

    os.makedirs(args.snapshot_dir, exist_ok=True)

    train_loader = data_loader(args)

    # model.optim_parameters(args) supplies the per-layer lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss()
    affinity_loss = AffinityFieldLoss(kl_margin=3.)

    # Set up tensorboard.
    writer = None
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration counter
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_affinity_value1_anchor = 0
        # Disabled losses: kept at 0 for logging only.
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_R_values = 0
        loss_A_values = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)

        # NOTE(review): every accumulation step reuses the same batch `dat`;
        # confirm this is intended when iter_size > 1.
        for _ in range(args.iter_size):
            # Episode sampling: e.g. for group 0 the training split has 15
            # classes [0, 1, 2, ...] and the validation split 5. Two training
            # classes are picked at random. Two images of one class form the
            # anchor and positive images (same class), and one image of the
            # other class is the negative (different class). The anchor is
            # the query image.
            # dat = anchor image/mask, positive image/mask (same class as the
            # anchor), negative image/mask (different class).
            anchor_img, anchor_mask, pos_img, pos_mask, neg_img, neg_mask = dat

            anchor_img = anchor_img.to(device)
            anchor_mask = anchor_mask.to(device)
            pos_img = pos_img.to(device)
            pos_mask = pos_mask.to(device)

            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)  # add channel dim
            pos_mask = torch.unsqueeze(pos_mask, dim=1)
            samples = torch.cat([pos_img, anchor_img], 0)

            pred = model(samples, pos_mask)

            # Segmentation + affinity loss at the anchor-mask resolution.
            mask_h, mask_w = anchor_mask.shape[-2:]
            pred = F.interpolate(pred, [mask_h, mask_w],
                                 mode='bilinear',
                                 align_corners=False)
            target = anchor_mask.squeeze().unsqueeze(0).long()
            loss_seg1_anchor = seg_loss(pred, target)
            loss_affinity = affinity_loss(pred, target)

            loss = loss_seg1_anchor + loss_affinity
            # Proper normalization over the accumulation steps.
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_affinity_value1_anchor += loss_affinity.item() / args.iter_size

        optimizer.step()

        count = count + 1
        if writer is not None:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_affinity_anchor': loss_affinity_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
                'loss_R': loss_R_values,
                'loss_A': loss_A_values
            }

            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_affinity = {3:.3f}, loss_D1 = {4:.3f}, loss_R = {5:.3f} loss_A = {6:.3f}'.format(
                count, args.num_steps, loss_seg_value1_anchor, loss_affinity_value1_anchor, loss_D_value1, loss_R_values,
                loss_A_values))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'voc2012_1_' + str(args.num_steps_stop) + '.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'voc2012_1_' + str(count) + '.pth'))

    if writer is not None:
        writer.close()
def main():
    """Segment a test image with a trained 3-class DeepLabV3+ model.

    The image ``doc/tests/img_cv.png`` is processed twice:

    1. As a whole via ``process_single_large_image``. The colour mask and a
       blended composite are saved as ``<out_file stem>_mask.png`` and
       ``<out_file stem>_com.png``.
    2. Patch by patch via ``decompose_image``. A mask and a composite are
       saved per patch next to the input image.

    The checkpoint path comes from ``--checkname`` and falls back to a
    hard-coded experiment path when the option is not given.

    Raises:
        FileNotFoundError: if an input image cannot be read.
    """

    def _str2bool(value):
        """Parse a boolean CLI value.

        ``type=bool`` would turn any non-empty string, including "False",
        into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    def _read_image(path):
        """Read an image with OpenCV, failing loudly instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError("could not read image '{}'".format(path))
        return img

    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out_stride',
                        type=int,
                        default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--rip_mode', type=str, default='patches-level2')
    # NOTE(review): store_true combined with default=True means this option
    # is always True.
    parser.add_argument('--use_sbd',
                        action='store_true',
                        default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers',
                        type=int,
                        default=8,
                        metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base_size',
                        type=int,
                        default=800,
                        help='base image size')
    parser.add_argument('--crop_size',
                        type=int,
                        default=800,
                        help='crop image size')
    parser.add_argument('--sync_bn',
                        type=_str2bool,
                        default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument(
        '--freeze_bn',
        type=_str2bool,
        default=False,
        help='whether to freeze bn parameters (default: False)')
    # cuda, seed and logging
    parser.add_argument('--gpus',
                        type=int,
                        default=1,
                        help='how many gpus to use (default=1)')
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        metavar='S',
                        help='random seed (default: 123)')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname',
                        type=str,
                        default=None,
                        help='set the checkpoint name')

    parser.add_argument('--exp_root', type=str, default='')
    args = parser.parse_args()

    args.device, args.cuda = get_available_device(args.gpus)

    nclass = 3

    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    # Use the default experiment checkpoint only when --checkname was not
    # given on the command line.
    if args.checkname is None:
        args.checkname = '/data2/data2/zewei/exp/RipData/DeepLabV3/patches/level2/CV5-1/model_best.pth.tar'
    # Load onto the CPU first; the model is moved to args.device afterwards,
    # so a GPU-saved checkpoint also opens on a CPU-only host.
    ckpt = torch.load(args.checkname, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(args.device)

    img_files = ['doc/tests/img_cv.png']
    out_file = 'doc/tests/img_seg.png'
    # splitext strips only the final extension, so directories containing
    # dots are handled correctly.
    out_stem = os.path.splitext(out_file)[0]

    transforms = Compose([
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    color_map = get_rip_labels()

    # Pass 1: the whole image at once.
    img_cv = _read_image(img_files[0])
    with torch.no_grad():
        pred = process_single_large_image(model, img_cv)
    mask = gen_mask(pred, nclass, color_map)
    out_img = composite_image(img_cv, mask, alpha=0.2)
    mask_file = out_stem + '_mask.png'
    com_file = out_stem + '_com.png'
    save_image(mask, mask_file)
    save_image(out_img, com_file)
    print(f'saved images {mask_file}, {com_file}')

    # Pass 2: patch by patch.
    with torch.no_grad():
        for img_file in img_files:
            # ext keeps its leading '.', e.g. ('doc/tests/img_cv', '.png').
            name, ext = os.path.splitext(img_file)

            img_cv = _read_image(img_file)
            patches = decompose_image(img_cv, None, (800, 800), (300, 700))
            print(f'Decompose input image into {len(patches)} patches.')
            for i, patch in patches.items():
                img = transforms(patch.image).unsqueeze(0).to(args.device)

                output = model(img).cpu().numpy()
                pred = np.argmax(output, axis=1)

                mask = gen_mask(pred[0], nclass, color_map)
                out_img = composite_image(patch.image, mask, alpha=0.2)
                seg_file = f'{name}_patch{i:02d}_seg{ext}'
                seg_img_file = f'{name}_patch{i:02d}_seg_img{ext}'
                save_image(mask, seg_file)
                save_image(out_img, seg_img_file)
                print(f'saved images {seg_file}, {seg_img_file}')
# Example #28 (0)
                        required=True)

    parser.add_argument('--output',
                        '-o',
                        metavar='output_path',
                        help='Output image',
                        required=True)

    args = parser.parse_args()

    dataset = "fashion_clothes"
    path = "./bestmodels/deep_clothes/checkpoint.pth.tar"
    nclass = 7

    #Initialize the DeeplabV3+ model
    model = DeepLab(num_classes=nclass, output_stride=8)

    #run model on CPU
    model.cpu()
    torch.set_num_threads(8)

    #error checking
    if not os.path.isfile(path):
        raise RuntimeError("no model found at'{}'".format(path))

    if not os.path.isfile(args.input):
        raise RuntimeError("no image found at'{}'".format(input))

    if os.path.exists(args.output):
        raise RuntimeError("Existed file or dir found at'{}'".format(
            args.output))