Exemplo n.º 1
0
def main():
    """Evaluate a DeepLab-ResNet snapshot on the images in data/list/val.txt.

    Command-line options are parsed with docopt from the module-level
    ``docstr``.  For every image name in the list, the prediction is
    colourised with ``decode_labels`` and written as a PNG under
    ``data/<--exp>``; a confusion matrix is accumulated against the
    ground-truth PNG and the mean IoU over all labels is printed at the end.

    Raises:
        IOError: if an input image or its ground-truth map cannot be read.
    """
    args = docopt(docstr, version='v0.1')
    print(args)

    gpu0 = int(args['--gpu0'])
    im_path = args['--testIMpath']
    gt_path = args['--testGTpath']
    num_labels = int(args['--NoLabels'])  # labels 0 .. num_labels - 1
    use_dgf = args['--dgf']

    model = deeplab_resnet.Res_Deeplab(num_labels, use_dgf, 4, 1e-2)
    model.eval().cuda(gpu0)

    # Read the image names with a context manager so the handle is closed,
    # and strip line endings explicitly: slicing off the last character would
    # truncate the final name when the file has no trailing newline.
    with open('data/list/val.txt') as list_file:
        img_list = [line.strip() for line in list_file if line.strip()]

    saved_state_dict = torch.load(args['--snapshots'])
    model.load_state_dict(saved_state_dict)

    save_path = os.path.join('data', args['--exp'])
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    hist = np.zeros((num_labels, num_labels))
    for idx, name in enumerate(img_list):
        print('{}/{} ...'.format(idx + 1, len(img_list)))

        img_file = os.path.join(im_path, name + '.jpg')
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('could not read image: {}'.format(img_file))
        img = img.astype(float)
        img_original = img.copy() / 255.0
        # Subtract the per-channel constants used by this model.
        img[:, :, 0] = img[:, :, 0] - 104.008
        img[:, :, 1] = img[:, :, 1] - 116.669
        img[:, :, 2] = img[:, :, 2] - 122.675

        if use_dgf:
            inputs = [img, img_original]
        else:
            # Without the guided filter the image is placed in the top-left
            # corner of a fixed 513x513 zero canvas.
            inputs = [np.zeros((513, 513, 3))]
            inputs[0][:img.shape[0], :img.shape[1], :] = img

        # The comprehension variable is deliberately not called ``i``/``name``
        # so it can never clobber the image name (comprehension variables
        # leak into the enclosing scope on Python 2).
        with torch.no_grad():
            output = model(*[
                torch.from_numpy(arr[np.newaxis, :].transpose(
                    0, 3, 1, 2)).float().cuda(gpu0) for arr in inputs
            ])
        if not use_dgf:
            interp = nn.Upsample(size=(513, 513),
                                 mode='bilinear',
                                 align_corners=True)
            output = interp(output)
            # Crop the padded canvas back to the original image extent.
            output = output[:, :, :img.shape[0], :img.shape[1]]

        output = output.cpu().data[0].numpy().transpose(1, 2, 0)
        output = np.argmax(output, axis=2)

        vis_output = decode_labels(output)
        imsave(os.path.join(save_path, name + '.png'), vis_output)

        gt_file = os.path.join(gt_path, name + '.png')
        gt = cv2.imread(gt_file, 0)
        if gt is None:
            raise IOError('could not read ground truth: {}'.format(gt_file))
        hist += fast_hist(gt.flatten(), output.flatten(), num_labels)

    # Per-class IoU = TP / (row sum + column sum - TP) of the confusion matrix.
    miou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    print("Mean iou = ", np.sum(miou) / len(miou))
Exemplo n.º 2
0
	def __init__(self, num_classes):
		"""Build the network on top of a pretrained DeepLab-ResNet.

		The backbone is created with ``deeplab_resnet.Res_Deeplab()`` and
		filled from ``models/MS_DeepLab_resnet_trained_VOC.pth``.  Four
		``PyramidPool(2048, 512, k)`` modules (k = 1, 2, 3, 6) and a
		convolutional head ending in ``num_classes`` output channels are
		added and passed to ``initialize_weights``.

		Args:
			num_classes: number of output channels of the last 1x1 conv.
		"""
		super(PSPNet, self).__init__()
		print("initializing model")

		# Pretrained backbone.
		backbone = deeplab_resnet.Res_Deeplab()
		backbone.load_state_dict(
			torch.load("models/MS_DeepLab_resnet_trained_VOC.pth"))
		self.resnet = backbone

		# layer5a .. layer5d, created in the same order as before so module
		# registration order is unchanged.
		for suffix, bins in zip("abcd", (1, 2, 3, 6)):
			setattr(self, "layer5" + suffix, PyramidPool(2048, 512, bins))

		# Head: 3x3 conv -> BN -> ReLU -> dropout -> 1x1 conv.
		head_layers = [
			nn.Conv2d(4096, 512, 3, padding=1, bias=False),
			nn.BatchNorm2d(512, momentum=.95),
			nn.ReLU(inplace=True),
			nn.Dropout(.1),
			nn.Conv2d(512, num_classes, 1),
		]
		self.final = nn.Sequential(*head_layers)

		# Only the newly added modules are handed to initialize_weights;
		# the backbone keeps the weights loaded above.
		initialize_weights(self.layer5a, self.layer5b, self.layer5c,
		                   self.layer5d, self.final)
Exemplo n.º 3
0
def main():
    """Segment a single image and save the colourised label map next to it.

    Options are parsed with docopt from the module-level ``docstr``.  The
    image given by ``--img_path`` is run through a 21-label DeepLab-ResNet
    restored from ``--snapshots``; the arg-max prediction is colourised with
    ``decode_labels`` and written to ``<image dir>/<image name>_labels.png``.

    Raises:
        IOError: if the input image cannot be read.
    """
    args = docopt(docstr, version='v0.1')
    print(args)

    gpu0 = int(args['--gpu0'])
    im_path = args['--img_path']

    # Validate the input before the (expensive) model load.  cv2.imread
    # returns None on failure instead of raising.
    img = cv2.imread(im_path)
    if img is None:
        raise IOError('could not read image: {}'.format(im_path))
    img = img.astype(float)

    model = deeplab_resnet.Res_Deeplab(21, True, 4, 1e-2)
    model.load_state_dict(torch.load(args['--snapshots']))
    model.eval().cuda(gpu0)

    # The network takes two inputs: the offset-subtracted image and the
    # original image scaled to [0, 1].
    img_original = img.copy() / 255.0
    img[:, :, 0] = img[:, :, 0] - 104.008
    img[:, :, 1] = img[:, :, 1] - 116.669
    img[:, :, 2] = img[:, :, 2] - 122.675

    with torch.no_grad():
        # Add a batch axis and reorder axes with (0, 3, 1, 2) for each input.
        tensors = [
            torch.from_numpy(arr[np.newaxis, :].transpose(0, 3, 1, 2))
            .float().cuda(gpu0)
            for arr in (img, img_original)
        ]
        output = model(*tensors)
    output = output.cpu().data[0].numpy().transpose(1, 2, 0)
    output = np.argmax(output, axis=2)

    vis_output = decode_labels(output)

    output_directory = os.path.dirname(im_path)
    output_name = os.path.splitext(os.path.basename(im_path))[0]
    save_path = os.path.join(output_directory,
                             '{}_labels.png'.format(output_name))
    imsave(save_path, vis_output)
Exemplo n.º 4
0
    def __init__(self, num_classes):
        """Build the network on top of a frozen, pretrained DeepLab-ResNet.

        The backbone is created with ``deeplab_resnet.Res_Deeplab()`` and
        filled from ``../models/MS_DeepLab_resnet_trained_VOC.pth``.  The
        parameters of its Conv2d and BatchNorm2d layers are frozen; four
        ``PSPDec`` modules and a small convolutional head ending in
        ``num_classes`` channels are added on top and remain trainable.

        Args:
            num_classes: number of output channels of the last 1x1 conv.

        Note:
            Because the backbone parameters no longer require gradients,
            an optimizer should be given only the trainable parameters, e.g.
            ``filter(lambda p: p.requires_grad, model.parameters())``.
        """
        super(PSPNet, self).__init__()

        init_net = deeplab_resnet.Res_Deeplab()
        state = torch.load("../models/MS_DeepLab_resnet_trained_VOC.pth")
        init_net.load_state_dict(state)
        self.resnet = init_net

        # Freeze the backbone.  Only the backbone is registered at this
        # point, so self.modules() does not touch the layers added below.
        # ``requires_grad`` has to be set on the parameters themselves:
        # assigning it on a Module merely creates an unused attribute and
        # leaves every weight trainable.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                for param in m.parameters():
                    param.requires_grad = False

        self.layer5a = PSPDec(21, 5, 18)
        self.layer5b = PSPDec(21, 5, 9)
        self.layer5c = PSPDec(21, 5, 6)
        self.layer5d = PSPDec(21, 5, 3)

        # Head input width is 41 channels
        # (presumably 21 backbone + 4 x 5 PSPDec channels - confirm in forward()).
        self.final = nn.Sequential(
            nn.Conv2d(41, 25, 3, padding=1, bias=False),
            nn.BatchNorm2d(25, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(25, num_classes, 1),
        )
Exemplo n.º 5
0
def convert(img_p, layers):
    """Copy Caffe weights into the PyTorch DeepLab model and verify them.

    Every entry of the PyTorch state dict is filled from the Caffe model via
    ``parse_pth_varnames``.  The converted network is then run on ``img_p``
    and selected intermediate activations and all four outputs are compared
    with the corresponding Caffe blobs through ``dist_``.  Finally the state
    dict is saved to ``data/MS_DeepLab_resnet_trained_VOC.pth``.

    Args:
        img_p: image array handed to ``load_caffe``; for the PyTorch model a
            batch axis is prepended and the axes are reordered with
            (0, 3, 1, 2).
        layers: forwarded unchanged to ``parse_pth_varnames``.
    """
    caffe_model = load_caffe(img_p)

    param_provider = CaffeParamProvider(caffe_model)
    model = deeplab_resnet.Res_Deeplab(21)

    new_state_dict = OrderedDict()
    for var_name in list(model.state_dict()):
        data = parse_pth_varnames(param_provider, var_name, layers)
        new_state_dict[var_name] = torch.from_numpy(data).float()
    model.load_state_dict(new_state_dict)

    # Inputs seen by the hooked modules, in the order the forward pass
    # reaches them.
    o = []

    def hook(module, input, output):
        o.append(input[0].data.numpy())

    model.Scale.conv1.register_forward_hook(hook)    # 0: data
    model.Scale.bn1.register_forward_hook(hook)      # 1: conv1 out
    model.Scale.relu.register_forward_hook(hook)     # 2: batch norm out
    model.Scale.maxpool.register_forward_hook(hook)  # 3: bn1 + relu out
    model.Scale.layer1._modules['0'].conv1.register_forward_hook(
        hook)  # 4: pool1 out
    model.Scale.layer1._modules['1'].conv1.register_forward_hook(
        hook)  # 5: res2a out
    model.Scale.layer5.conv2d_list._modules['0'].register_forward_hook(
        hook)  # 6: res5c out

    model.eval()
    net_input = img_p[np.newaxis, :].transpose(0, 3, 1, 2)
    output = model(
        Variable(torch.from_numpy(net_input).float(), volatile=True))

    # Compare PyTorch activations / outputs with the matching Caffe blobs.
    blobs = caffe_model.blobs
    comparisons = (
        ('data', o[0]),
        ('conv1', o[3]),
        ('pool1', o[4]),
        ('res2a', o[5]),
        ('res5c', o[6]),
        ('fc1_voc12', output[0].data.numpy()),
        ('fc1_voc12_res075_interp', output[1].data.numpy()),
        ('fc1_voc12_res05', output[2].data.numpy()),
        ('fc_fusion', output[3].data.numpy()),
    )
    for blob_name, activation in comparisons:
        dist_(blobs[blob_name].data, activation)

    print('input image shape', net_input.shape)
    print('output shapes -')
    for a in output:
        print(a.data.numpy().shape)

    torch.save(model.state_dict(), 'data/MS_DeepLab_resnet_trained_VOC.pth')
def convert(img_p, layers):
    """Copy Caffe weights into the PyTorch DeepLab model and verify them.

    Every entry of the PyTorch state dict is filled from the Caffe model via
    ``parse_pth_varnames``.  The converted network is then run on ``img_p``
    and selected intermediate activations plus ``output[3]`` are compared
    with the corresponding Caffe blobs through ``dist_``.  Finally the state
    dict is saved to ``data/MS_DeepLab_resnet_pretrained_COCO_init.pth``.

    The body runs unchanged on Python 2 and Python 3 and prints the same
    text on both.

    Args:
        img_p: image array handed to ``load_caffe``; for the PyTorch model a
            batch axis is prepended and the axes are reordered with
            (0, 3, 1, 2).
        layers: forwarded unchanged to ``parse_pth_varnames``.
    """
    caffe_model = load_caffe(img_p)

    param_provider = CaffeParamProvider(caffe_model)
    model = deeplab_resnet.Res_Deeplab(21)

    new_state_dict = OrderedDict()
    # list(...) instead of keys()[:]: dict views cannot be sliced on Python 3.
    for var_name in list(model.state_dict()):
        data = parse_pth_varnames(param_provider, var_name, layers)
        new_state_dict[var_name] = torch.from_numpy(data).float()

    model.load_state_dict(new_state_dict)

    # Inputs seen by the hooked modules, in the order the forward pass
    # reaches them.
    o = []

    def hook(module, input, output):
        o.append(input[0].data.numpy())

    model.Scale.conv1.register_forward_hook(hook)    # 0: data
    model.Scale.bn1.register_forward_hook(hook)      # 1: conv1 out
    model.Scale.relu.register_forward_hook(hook)     # 2: batch norm out
    model.Scale.maxpool.register_forward_hook(hook)  # 3: bn1 + relu out
    model.Scale.layer1._modules['0'].conv1.register_forward_hook(
        hook)  # 4: pool1 out
    model.Scale.layer1._modules['1'].conv1.register_forward_hook(
        hook)  # 5: res2a out
    model.Scale.layer5.conv2d_list._modules['0'].register_forward_hook(
        hook)  # 6: res5c out

    model.eval()
    net_input = img_p[np.newaxis, :].transpose(0, 3, 1, 2)
    output = model(
        Variable(torch.from_numpy(net_input).float(), volatile=True))

    dist_(caffe_model.blobs['data'].data, o[0])
    dist_(caffe_model.blobs['conv1'].data, o[3])
    dist_(caffe_model.blobs['pool1'].data, o[4])
    dist_(caffe_model.blobs['res2a'].data, o[5])
    dist_(caffe_model.blobs['res5c'].data, o[6])
    dist_(caffe_model.blobs['fc1_voc12'].data, output[3].data.numpy())

    # Single-argument print calls behave identically on Python 2 and 3.
    print('input image shape {}'.format(net_input.shape))
    print('output shapes -')
    for a in output:
        print(a.data.numpy().shape)

    torch.save(model.state_dict(),
               'data/MS_DeepLab_resnet_pretrained_COCO_init.pth')
    def __init__(self, args):
        """Set up model, data, optimisation and logging from parsed options.

        Args:
            args: option mapping providing ``--wtDecay``, ``--lr``,
                ``--epoch``, ``--NoLabels``, ``--GPUID``, ``--pretrain``,
                ``--ListPath`` and ``--summaryPath``.
        """
        # Hyper-parameters taken from the command line.
        self.batch_size = 1
        self.weight_decay = float(args['--wtDecay'])
        self.base_lr = float(args['--lr'])
        self.epoch_num = int(args['--epoch'])
        self.label_num = int(args['--NoLabels'])
        self.GPU = int(args['--GPUID'])

        # Network: restore the pretrained weights, switch to float / eval
        # mode and move it to the selected GPU.
        self.model = deeplab_resnet.Res_Deeplab(self.label_num)
        pretrained_state = torch.load(args['--pretrain'])
        self.model.load_state_dict(pretrained_state)
        self.model.float()
        self.model.eval()
        self.model.cuda(self.GPU)

        # Training data.
        self.datasets = dataset.create_dataset(args['--ListPath'])

        # Classification cross-entropy loss; targets equal to 255 are ignored.
        self.criterion = nn.CrossEntropyLoss(ignore_index=255)

        # Two parameter groups: the second one is trained with ten times
        # the base learning rate.
        param_groups = [
            {
                'params': get_1x_lr_params_NOscale(self.model),
                'lr': self.base_lr,
            },
            {
                'params': get_10x_lr_params(self.model),
                'lr': 10 * self.base_lr,
            },
        ]
        self.optimizer = optim.SGD(param_groups,
                                   lr=self.base_lr,
                                   momentum=0.9,
                                   weight_decay=self.weight_decay)

        # Halve the learning rate every 5 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=5,
                                                         gamma=0.5,
                                                         last_epoch=-1)
        self.optimizer.zero_grad()

        # Summary writer for training logs.
        self.writer = SummaryWriter(log_dir=args['--summaryPath'])
Exemplo n.º 8
0
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """

    b = []
    b.append(model.Scale.layer5.parameters())

    for j in range(len(b)):
        for i in b[j]:
            yield i


# Snapshots are written below data/snapshots; create it on first use.
if not os.path.exists('data/snapshots'):
    os.makedirs('data/snapshots')

no_labels = int(args['--NoLabels'])
model = deeplab_resnet.Res_Deeplab(no_labels)

saved_state_dict = torch.load(
    'data/MS_DeepLab_resnet_pretrained_COCO_init.pth')
if no_labels != 21:
    # The checkpoint's layer5 entries are only reused for 21 labels.  For any
    # other label count they are replaced by the freshly constructed model's
    # own layer5 tensors before loading.
    # The model's state dict is fetched once, not once per key.
    model_state = model.state_dict()
    for i in saved_state_dict:
        # Keys look like: Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')
        if len(i_parts) > 1 and i_parts[1] == 'layer5':
            saved_state_dict[i] = model_state[i]

model.load_state_dict(saved_state_dict)

max_iter = int(args['--maxIter'])
batch_size = 1
weight_decay = float(args['--wtDecay'])
Exemplo n.º 9
0
def main():
    """Train DeepLab-ResNet (21 labels) with a poly learning-rate policy.

    Options come from the module-level ``parser``.  The model is initialised
    from ``/data/MS_DeepLab_resnet_pretrained_COCO_init.pth``; for every batch
    the losses of all network outputs are summed, scaled by ``1 / iter_size``
    and back-propagated.  Every ``iter_size`` steps the optimizer is stepped
    and rebuilt with the new poly learning rate; every 1000 steps a snapshot
    is written to ``data/snapshots``.

    NOTE(review): ``args`` and ``gpu0`` are not defined in this function and
    are assumed to exist at module scope - confirm, or derive them from
    ``opt``.  The model is also never moved to the GPU here although the
    inputs are - confirm whether ``model.cuda()`` is missing.
    """
    opt = parser.parse_args()
    print(opt)

    trainloader = datasets.init_data.load_data(opt)

    model = deeplab_resnet.Res_Deeplab(21)
    saved_state_dict = torch.load(
        '/data/MS_DeepLab_resnet_pretrained_COCO_init.pth')
    model.load_state_dict(saved_state_dict)

    max_iter = opt.maxIter
    weight_decay = opt.wtDecay
    base_lr = opt.lr
    # NOTE(review): ``args`` must be provided by the enclosing module.
    iter_size = int(args['--iterSize'])

    def make_optimizer(lr):
        # SGD with two parameter groups; the second uses 10x the rate.
        return optim.SGD([{
            'params': get_1x_lr_params_NOscale(model),
            'lr': lr
        }, {
            'params': get_10x_lr_params(model),
            'lr': 10 * lr
        }],
                         lr=lr,
                         momentum=0.9,
                         weight_decay=weight_decay)

    # Both groups derive from the same base rate (``opt`` is only known to
    # expose ``lr``; ``opt.base_lr`` was never read anywhere else).
    optimizer = make_optimizer(base_lr)

    # ``step`` is the iteration counter; the builtin ``iter`` was used for
    # this before, which fails on the first ``%``.
    for step, (images, labels) in enumerate(trainloader):
        images = Variable(images).cuda()
        out = model(images)

        # Sum the loss over every network output.  A dedicated index name
        # keeps the inner loop from overwriting the iteration counter.
        # NOTE(review): assumes ``labels[k]`` is the target for ``out[k]``.
        loss = loss_calc(out[0], labels[0], gpu0)
        for scale in range(1, len(out)):
            loss = loss + loss_calc(out[scale], labels[scale], gpu0)

        loss = loss / iter_size
        loss.backward()

        print('iter =  {} of {} completed, loss =  {}'.format(
            step, max_iter, iter_size * (loss.data.cpu().numpy())))

        if step % iter_size == 0:
            optimizer.step()
            lr_ = lr_poly(base_lr, step, max_iter, 0.9)
            print('(poly lr policy) learning rate {}'.format(lr_))
            # Rebuilding the optimizer applies the new rate (and, as before,
            # discards the momentum buffers).
            optimizer = make_optimizer(lr_)
            optimizer.zero_grad()

        if step % 1000 == 0 and step != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       'data/snapshots/VOC12_scenes_' + str(step) + '.pth')
Exemplo n.º 10
0
def main():

    global opt, model
    opt = parser.parse_args()
    print opt

    ##########configs########
    #     isSegReg = False
    #     isDiceLoss = True # for dice loss
    #     isSoftmaxLoss = True #for softmax loss
    #     isResidualEnhancement = False #using ensemble learning to enhance residual learning
    #     isViewExpansion = True #using dilation to expand receptive filed
    #     isAdLoss = True #for adverarail loss
    #     isSpatialDropOut = False
    #     isFocalLoss = False #we use focal loss to escale training dominated by easy examples
    #     isSampleImportanceFromAd = False #we set batch sample importance using the loss from adversarial training
    #     dropoutRate = 0.25
    #     lambdaAD = 0 # the coefficients before the Adversarial training
    #     adImportance = 0 # attention from adversarial training
    #     how2normalize = 3 #1. mu/(max-min); 2. mu/(percent_99 - percent_1); 3. mu/std
    #     lr = 1e-4 #one of the most important hyper-parameters:
    #     prefixModelName = 'Segmentor_wdice_wce_lrdce_viewExpansion_1111_'
    #     prefixPredictedFN = 'preSub_wdice_wce_lrdce_viewExpansion_1111_'
    #     showTrainLossEvery = 100
    #     showTestPerformanceEvery = 2000
    #     decLREvery = 25000 #decrease learning rate every xxx iterations
    #     saveModelEvery = 2000
    #     numofIters = 200000

    ##########configs########

    if opt.isSegReg:
        netG = ResSegRegNet()
    elif opt.isContourLoss:
        netG = ResSegContourNet(isRandomConnection=opt.isResidualEnhancement,
                                isSmallDilation=opt.isViewExpansion,
                                isSpatialDropOut=opt.isSpatialDropOut,
                                dropoutRate=opt.dropoutRate)
    else:
        netG = ResSegNet(isRandomConnection=opt.isResidualEnhancement,
                         isSmallDilation=opt.isViewExpansion,
                         isSpatialDropOut=opt.isSpatialDropOut,
                         dropoutRate=opt.dropoutRate)

#     netG =PSPNet(num_classes=4)
    netG = deeplab_resnet.Res_Deeplab(NoLabels=4)
    #netG.apply(weights_init)
    netG = netG.cuda()

    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    params = list(netG.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    #     optimizerG =optim.SGD(netG.parameters(),lr=1e-2)
    #     optimizerG =optim.Adam(netG.parameters(),lr=opt.lr)
    optimizerG = optim.Adam(filter(lambda p: p.requires_grad,
                                   netG.parameters()),
                            lr=opt.lr)

    #     optimizerD =optim.SGD(netD.parameters(),lr=1e-4)
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr)

    criterion_MSE = nn.MSELoss()
    given_weight = torch.FloatTensor([1, 4, 8, 8])
    given_weight = given_weight.cuda()
    #     criterion_NLL2D = nn.NLLLoss2d(weight=given_weight)
    criterion_CE2D = CrossEntropy2d(weight=given_weight)

    criterion_BCE2D = CrossEntropy2d()  #for contours

    #     criterion_dice = DiceLoss4Organs(organIDs=[1,2,3], organWeights=[1,1,1])
    #     criterion_dice = WeightedDiceLoss4Organs()
    criterion_dice = myWeightedDiceLoss4Organs(organIDs=[0, 1, 2, 3],
                                               organWeights=given_weight)

    criterion_focal = myFocalLoss(4, alpha=given_weight, gamma=2)

    criterion = nn.BCELoss()
    criterion = criterion.cuda()
    criterion_dice = criterion_dice.cuda()
    criterion_MSE = criterion_MSE.cuda()
    criterion_CE2D = criterion_CE2D.cuda()
    criterion_BCE2D = criterion_BCE2D.cuda()
    criterion_focal = criterion_focal.cuda()
    softmax2d = nn.Softmax2d()
    #     inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
    #     targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
    #     trainset=data_utils.TensorDataset(inputs, targets)
    #     trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    #     inputs=torch.randn(1000,1,32,32)
    #     targets=torch.LongTensor(1000)

    path_test = '/home/dongnie/warehouse/mrs_data'
    path_test = '/shenlab/lab_stor5/dongnie/pelvic'
    path_patients_h5 = '/home/dongnie/warehouse/BrainEstimation/brainH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourBatchH5'
    path_patients_h5 = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSeg2D64H5'
    #     path_patients_h5 = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5'
    #     path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegPartH5/' #only contains 1-15
    path_patients_h5_test = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourH5Test'
    path_patients_h5_test = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSeg2D64H5Test'
    #     path_patients_h5_test ='/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5Test'

    #     batch_size = 10
    if opt.isSegReg:
        data_generator = Generator_2D_slices_variousKeys(
            path_patients_h5,
            opt.batchSize,
            inputKey='dataMR2D',
            outputKey='dataSeg2D',
            regKey1='dataBladder2D',
            regKey2='dataProstate2D',
            regKey3='dataRectum2D')
    elif opt.isContourLoss:
        data_generator = Generator_2D_slicesV1(path_patients_h5,
                                               opt.batchSize,
                                               inputKey='dataMR2D',
                                               segKey='dataSeg2D',
                                               contourKey='dataContour2D')
    else:
        data_generator = Generator_2D_slices(path_patients_h5,
                                             opt.batchSize,
                                             inputKey='dataMR2D',
                                             outputKey='dataSeg2D')

    data_generator_test = Generator_2D_slices(path_patients_h5_test,
                                              opt.batchSize,
                                              inputKey='dataMR2D',
                                              outputKey='dataSeg2D')

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            #             opt.start_epoch = checkpoint["epoch"] + 1
            #             netG.load_state_dict(checkpoint["model"].state_dict())
            opt.start_epoch = 4999 + 1
            netG.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

########### We'd better use dataloader to load a lot of data,and we also should train several epoches###############
    running_loss = 0.0
    start = time.time()
    for iter in range(opt.start_epoch, opt.numofIters + 1):
        #print('iter %d'%iter)

        if opt.isSegReg:
            inputs, labels, regGT1, regGT2, regGT3 = data_generator.next()
        elif opt.isContourLoss:
            inputs, labels, contours = data_generator.next()
        else:
            inputs, labels = data_generator.next()

        labels = np.squeeze(labels)
        labels = zoomImages(labels, rate=opt.zoomRate)
        labels = labels.astype(int)

        if opt.isContourLoss:
            contours = np.squeeze(contours)
            contours = contours.astype(int)
            contours = torch.from_numpy(contours)
            contours = contours.cuda()
            contours = Variable(contours)

        inputs = torch.from_numpy(inputs)
        labels = torch.from_numpy(labels)
        inputs = inputs.cuda()
        labels = labels.cuda()
        #we should consider different data to train

        #wrap them into Variable
        inputs, labels = Variable(inputs), Variable(labels)

        #zero the parameter gradients
        #netD.zero_grad()
        if opt.isAdLoss:
            #forward + backward +optimizer
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(inputs)
            outputG = softmax2d(outputG)  #batach
            #         print 'outputG: ',outputG.size(),'labels: ',labels.size()
            #         print 'outputG: ', outputG.data[0].size()
            outputG = outputG.data.max(1)[1]
            #outputG = torch.squeeze(outputG) #[N,C,W,H]
            labels = labels.unsqueeze(1)  #expand the 1st dim
            #         print 'outputG: ',outputG.size(),'labels: ',labels.size()
            outputR = labels.type(torch.FloatTensor).cuda()  #output_Real
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
            outputD_real = netD(outputR)
            #         print 'size outputG: ',outputG.unsqueeze(1).size()
            outputD_fake = netD(outputG.unsqueeze(1))

            ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            batch_size = inputs.size(0)
            #print(inputs.size())
            #train with real data
            #         real_label = torch.FloatTensor(batch_size)
            #         real_label.data.resize_(batch_size).fill_(1)
            real_label = torch.ones(batch_size, 1)
            real_label = real_label.cuda()
            #print(real_label.size())
            real_label = Variable(real_label)
            #print(outputD_real.size())
            loss_real = criterion(outputD_real, real_label)
            loss_real.backward()
            #train with fake data
            fake_label = torch.zeros(batch_size, 1)
            #         fake_label = torch.FloatTensor(batch_size)
            #         fake_label.data.resize_(batch_size).fill_(0)
            fake_label = fake_label.cuda()
            fake_label = Variable(fake_label)
            loss_fake = criterion(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake

            optimizerD.step()

        ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))

        #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the
        #angel of equation (note the max and min difference for generator and discriminator)
        if opt.isAdLoss:
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(inputs)
            outputG = outputG.data.max(1)[1]
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
            #         print 'outputG shape, ',outputG.size()

            outputD = netD(outputG.unsqueeze(1))
            averProb = outputD.data.cpu().mean()
            #             print 'prob: ',averProb
            #             adImportance = computeAttentionWeight(averProb)
            adImportance = computeSampleAttentionWeight(averProb)
            lossG_D = opt.lambdaAD * criterion(
                outputD, real_label
            )  #note, for generator, the label for outputG is real
            lossG_D.backward(retain_graph=True)

        if opt.isSegReg:
            outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
        elif opt.isContourLoss:
            outputG, outputContour = netG(inputs)
        else:
            outputG = netG(
                inputs)  #here I am not sure whether we should use twice or not
        netG.zero_grad()

        outputG = outputG[0]
        #         print 'outputG size: ',outputG.size(),' label size: ',labels.size()
        if opt.isFocalLoss:
            lossG_focal = criterion_focal(outputG, torch.squeeze(labels))
            lossG_focal.backward(retain_graph=True)  #compute gradients

        if opt.isSoftmaxLoss:
            if opt.isSampleImportanceFromAd:
                lossG_G = (1 + adImportance) * criterion_CE2D(
                    outputG, torch.squeeze(labels))
            else:
                lossG_G = criterion_CE2D(outputG, torch.squeeze(labels))

            lossG_G.backward(retain_graph=True)  #compute gradients

        if opt.isContourLoss:
            lossG_contour = criterion_BCE2D(outputContour, contours)
            lossG_contour.backward(retain_graph=True)
#         criterion_dice(outputG,torch.squeeze(labels))
#         print 'hahaN'
        if opt.isSegReg:
            lossG_Reg1 = criterion_MSE(outputReg1, regGT1)
            lossG_Reg2 = criterion_MSE(outputReg2, regGT2)
            lossG_Reg3 = criterion_MSE(outputReg3, regGT3)
            lossG_Reg = lossG_Reg1 + lossG_Reg2 + lossG_Reg3
            lossG_Reg.backward()

        if opt.isDiceLoss:
            #             print 'isDiceLoss line278'
            #             criterion_dice = myWeightedDiceLoss4Organs(organIDs=[0,1,2,3], organWeights=[1,4,8,6])
            if opt.isSampleImportanceFromAd:
                loss_dice = (1 + adImportance) * criterion_dice(
                    outputG, torch.squeeze(labels))
            else:
                loss_dice = criterion_dice(outputG, torch.squeeze(labels))
#             loss_dice = myDiceLoss4Organs(outputG,torch.squeeze(labels)) #succeed
#             loss_dice.backward(retain_graph=True) #compute gradients for dice loss
            loss_dice.backward()  #compute gradients for dice loss

        #lossG_D.backward()

        #for other losses, we can define the loss function following the pytorch tutorial

        optimizerG.step()  #update network parameters
        #         print 'gradients of parameters****************************'
        #         [x.grad.data for x in netG.parameters()]
        #         print x.grad.data[0]
        #         print '****************************'
        if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg and opt.isFocalLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[
                0] + lossG_Reg + lossG_focal
        if opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[
                0] + lossG_Reg
        if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[
                0] + lossG_Reg
        elif opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G + lossG_Reg
        elif opt.isDiceLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0] + lossG_Reg
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[0]
        elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[0]
        elif opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G
        elif opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal
        elif opt.isDiceLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0]
        elif opt.isSoftmaxLoss:
            lossG = lossG_G
        #print('loss for generator is %f'%lossG.data[0])
        #print statistics
        running_loss = running_loss + lossG.data[0]
        #         print 'running_loss is ',running_loss,' type: ',type(running_loss)

        #         print type(outputD_fake.cpu().data[0].numpy())

        if iter % opt.showTrainLossEvery == 0:  #print every 2000 mini-batches
            print '************************************************'
            print 'time now is: ' + time.asctime(time.localtime(time.time()))
            if opt.isAdLoss:
                print 'the outputD_real for iter {}'.format(
                    iter), ' is ', outputD_real.cpu().data[0].numpy()[0]
                print 'the outputD_fake for iter {}'.format(
                    iter), ' is ', outputD_fake.cpu().data[0].numpy()[0]
                print 'loss for discriminator at iter ', iter, ' is %f' % lossD.data[
                    0]
#             print 'running loss is ',running_loss
            print 'average running loss for generator between iter [%d, %d] is: %.3f' % (
                iter - 100 + 1, iter, running_loss / 100)

            print 'total loss for generator at iter ', iter, ' is %f' % lossG.data[
                0]
            if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
                print 'lossG_D, lossG_G and loss_dice loss_Reg are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0], loss_dice.data[0],
                    lossG_Reg.data[0])
            elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D, lossG_G and loss_dice are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0], loss_dice.data[0])
            elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D, lossG_focal and loss_dice are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_focal.data[0], loss_dice.data[0])
            elif opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D and lossG_G are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0])
            elif opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D and lossG_focal are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_focal.data[0])
            elif opt.isDiceLoss and opt.isAdLoss:
                print 'lossG_D and loss_dice are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], loss_dice.data[0])
            elif opt.isSoftmaxLoss:
                print ' lossG_G are %.2f respectively.' % (lossG_G.data[0])

            if opt.isContourLoss:
                print 'lossG_contour is {}'.format(lossG_contour.data[0])

            print 'cost time for iter [%d, %d] is %.2f' % (
                iter - 100 + 1, iter, time.time() - start)
            print '************************************************'
            running_loss = 0.0
            start = time.time()
        if iter % opt.saveModelEvery == 0:  #save the model
            torch.save(netG.state_dict(), opt.prefixModelName + '%d.pt' % iter)
            print 'save model: ' + opt.prefixModelName + '%d.pt' % iter

        if iter % opt.decLREvery == 0 and iter > 0:
            opt.lr = opt.lr * 0.1
            adjust_learning_rate(optimizerG, opt.lr)
            print 'now the learning rate is {}'.format(opt.lr)

        if iter % opt.showTestPerformanceEvery == 0:  #test one subject
            # to test on the validation dataset in the format of h5
            inputs, labels = data_generator_test.next()
            labels = np.squeeze(labels)
            labels = zoomImages(labels, rate=opt.zoomRate)
            labels = labels.astype(int)
            inputs = torch.from_numpy(inputs)
            labels = torch.from_numpy(labels)
            inputs = inputs.cuda()
            labels = labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(
                    inputs
                )  #here I am not sure whether we should use twice or not
            outputG = outputG[0]
            lossG_G = criterion_CE2D(outputG, torch.squeeze(labels))
            loss_dice = criterion_dice(outputG, torch.squeeze(labels))
            print '.......come to validation stage: iter {}'.format(
                iter), '........'
            print 'lossG_G and loss_dice are %.2f and %.2f respectively.' % (
                lossG_G.data[0], loss_dice.data[0])

            ####release all the unoccupied memory####
            torch.cuda.empty_cache()

            mr_test_itk = sitk.ReadImage(
                os.path.join(path_test, 'img50_nocrop.nii.gz'))
            ct_test_itk = sitk.ReadImage(
                os.path.join(path_test, 'img50_label_nie_nocrop.nii.gz'))

            mrnp = sitk.GetArrayFromImage(mr_test_itk)
            mu = np.mean(mrnp)

            ctnp = sitk.GetArrayFromImage(ct_test_itk)

            #for training data in pelvicSeg
            if opt.how2normalize == 1:
                maxV, minV = np.percentile(mrnp, [99, 1])
                print 'maxV,', maxV, ' minV, ', minV
                mrnp = (mrnp - mu) / (maxV - minV)
                print 'unique value: ', np.unique(ctnp)

            #for training data in pelvicSeg
            if opt.how2normalize == 2:
                maxV, minV = np.percentile(mrnp, [99, 1])
                print 'maxV,', maxV, ' minV, ', minV
                mrnp = (mrnp - mu) / (maxV - minV)
                print 'unique value: ', np.unique(ctnp)

            #for training data in pelvicSegRegH5
            if opt.how2normalize == 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu) / std
                print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                    mrnp)

#             full image version with average over the overlapping regions
#             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)
            if opt.how2normalize == 4:
                maxV, minV = np.percentile(mrnp, [99.2, 1])
                print 'maxV is: ', np.ndarray.max(mrnp)
                mrnp[np.where(mrnp > maxV)] = maxV
                print 'maxV is: ', np.ndarray.max(mrnp)
                mu = np.mean(mrnp)
                std = np.std(mrnp)
                mrnp = (mrnp - mu) / std
                print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                    mrnp)

            matFA = mrnp
            matGT = ctnp
            matGT = zoomImages(matGT, rate=opt.zoomRate)
            matOut, _ = testOneSubject(matFA, matGT, 4, [3, 64, 64], [1, 9, 9],
                                       [1, 9, 9], netG,
                                       opt.prefixModelName + '%d.pt' % iter)
            ct_estimated = np.zeros(
                [ctnp.shape[0], ctnp.shape[1], ctnp.shape[2]])
            print 'matOut shape: ', matOut.shape
            ct_estimated = matOut
            tmp_ctnp = ctnp
            ctnp = zoomImages(tmp_ctnp, opt.zoomRate)
            #             ct_estimated[:,y1:y2,x1:x2] = matOut

            ct_estimated = np.rint(ct_estimated)
            ct_estimated = denoiseImg(ct_estimated,
                                      kernel=np.ones((20, 20, 20)))
            diceBladder = dice(ct_estimated, ctnp, 1)
            diceProstate = dice(ct_estimated, ctnp, 2)
            diceRectumm = dice(ct_estimated, ctnp, 3)

            print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
            print 'gt: ', ctnp.dtype, ' shape: ', ct_estimated.shape
            print 'dice1 = ', diceBladder, ' dice2= ', diceProstate, ' dice3= ', diceRectumm
            volout = sitk.GetImageFromArray(ct_estimated)
            sitk.WriteImage(
                volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz')


#             netG.save_state_dict('Segmentor_model_%d.pt'%iter)
#             netD.save_state_dic('Discriminator_model_%d.pt'%iter)

    print('Finished Training')
Exemplo n.º 11
0
# `os` is required by the checkpoint path construction below; the script
# previously used os.path.join without importing the module.
import os
import sys

import numpy as np

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import deeplab_resnet
import torch.nn as nn

import cv2
from config import Config

# Number of labels handed to the network constructor (taken from the config).
max_label = Config.class_num
gpu0 = 0

# Build the network, switch it to evaluation mode and move it to GPU `gpu0`.
model = deeplab_resnet.Res_Deeplab(max_label)
model.eval()
counter = 0
model.cuda(gpu0)

# Restore the weights stored in "<model_name>_20000.pth" under
# Config.model_path.
saved_state_dict = torch.load(
    os.path.join(Config.model_path, Config.model_name + "_20000.pth"))
model.load_state_dict(saved_state_dict)

classes = np.array(('background', 'robot_hand', 'inhand_object'))

# 256-entry palette: one colour per class followed by black padding.
# The fractional RGB triples are wrapped in an outer list (giving a
# (1, 256, 3) array), scaled by 255 and truncated to uint8.
colormap = [(0, 0, 0), (0.5, 0, 0),
            (0, 0.5, 0)] + [(0, 0, 0)] * (256 - len(classes))
colormap = [colormap]
colormap = np.array(colormap) * 255
colormap = colormap.astype(np.uint8)
Exemplo n.º 12
0
def main():
    """Train the DeepLab-ResNet model described by the docopt command line.

    Options read from ``args``: ``--gpu0``, ``--lr``, ``--maxIter``,
    ``--iterSize``, ``--wtDecay``, ``--snapshots``, ``--NoLabels``,
    ``--dgf``, ``--ft``, ``--ft_model_path`` and ``--LISTpath``.

    The loop runs ``max_iter + 1`` iterations. Gradients are accumulated
    and an optimizer step is taken whenever the iteration index is a
    multiple of ``--iterSize``; the learning rate is then recomputed with
    ``lr_poly``. A snapshot is written to
    ``data/<--snapshots>/VOC12_scenes_<iteration>.pth`` every 1000
    iterations (iteration 0 excluded).
    """
    args = docopt(docstr, version='v0.1')
    print(args)

    cudnn.enabled = True
    gpu0 = int(args['--gpu0'])
    base_lr = float(args['--lr'])
    max_iter = int(args['--maxIter'])
    iter_size = int(args['--iterSize'])
    weight_decay = float(args['--wtDecay'])

    snapshot_dir = 'data/' + args['--snapshots']
    if not os.path.exists(snapshot_dir):
        os.makedirs(snapshot_dir)

    model = deeplab_resnet.Res_Deeplab(int(args['--NoLabels']), args['--dgf'],
                                       4, 1e-2)

    # Start either from a user-supplied fine-tuning checkpoint or from the
    # COCO-initialised weights shipped under data/.
    if args['--ft']:
        saved_state_dict = torch.load(args['--ft_model_path'])
    else:
        saved_state_dict = torch.load(
            'data/MS_DeepLab_resnet_pretrained_COCO_init.pth')
    # Overlay the checkpoint on the freshly built state dict so that keys
    # absent from the checkpoint keep their initial values.
    model_dict = model.state_dict()
    model_dict.update(saved_state_dict)
    model.load_state_dict(model_dict)

    # NOTE(review): the model is trained in eval() mode; presumably this
    # keeps the batch-norm statistics frozen -- confirm this is intended.
    model.float().eval().cuda(gpu0)

    # Group 0 uses the base learning rate, group 1 ten times that rate.
    optimizer = optim.SGD([{
        'params': get_1x_lr_params_NOscale(model),
        'lr': base_lr
    }, {
        'params': get_10x_lr_params(args, model),
        'lr': 10 * base_lr
    }],
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=weight_decay)
    optimizer.zero_grad()

    img_list = read_file(args['--LISTpath'])
    data_list = []
    # Make a list covering 10 epochs, though only the first
    # max_iter * batch_size entries of it are consumed.
    for _ in range(10):
        np.random.shuffle(img_list)
        data_list.extend(img_list)
    data_gen = chunker(data_list, 1)

    for step in range(max_iter + 1):
        inputs, label = get_data_from_chunk_v2(args, next(data_gen))
        inputs = [Variable(tensor).cuda(gpu0) for tensor in inputs]

        # Scale the loss so the accumulated gradient averages over
        # iter_size iterations.
        loss = loss_calc(model(*inputs), label, gpu0) / iter_size
        loss.backward()

        print('iter = ', step, 'of', max_iter, 'completed, loss = ',
              iter_size * (loss.data.cpu().numpy()))

        if step % iter_size == 0:
            optimizer.step()
            lr_ = lr_poly(base_lr, step, max_iter, 0.9)
            print('(poly lr policy) learning rate', lr_)
            # Update the learning rates in place. Re-creating the optimizer
            # here would throw away its momentum buffers after every step,
            # which silently disables momentum.
            optimizer.param_groups[0]['lr'] = lr_
            optimizer.param_groups[1]['lr'] = 10 * lr_
            optimizer.zero_grad()

        if step % 1000 == 0 and step != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                snapshot_dir + '/VOC12_scenes_' + str(step) + '.pth')
Exemplo n.º 13
0
    """
    This generator returns all the parameters for the last layer of the net,
    which does the classification of pixel into classes
    """

    b = []
    b.append(model.Scale.layer5.parameters())

    for j in range(len(b)):
        for i in b[j]:
            yield i

# Make sure the directory that will receive the snapshots exists.
if not os.path.exists(Config.snapshot_path):
    os.makedirs(Config.snapshot_path)

model = deeplab_resnet.Res_Deeplab(args.NoLabels)

saved_state_dict = torch.load(os.path.join(Config.model_path, "MS_DeepLab_resnet_pretrained_COCO_init.pth"))

if args.NoLabels != 21:
    # For a label count other than 21, every checkpoint entry whose second
    # dotted component is 'layer5' is replaced by the corresponding tensor
    # of the freshly built model.
    # The state dict is fetched once: it does not change inside the loop,
    # so rebuilding it for every key is wasted work.
    fresh_state_dict = model.state_dict()

    for i in saved_state_dict:
        # Example key: Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')
        # Keys without a dot have no second component; skip them instead
        # of failing with IndexError.
        if len(i_parts) > 1 and i_parts[1] == 'layer5':
            saved_state_dict[i] = fresh_state_dict[i]

model.load_state_dict(saved_state_dict)

max_iter = args.maxIter
batch_size = 1