def forward(self, image, mask, vertex, vertex_weights):
    seg_pred, vertex_pred = self.net(image)
    loss_seg = self.criterion(seg_pred, mask)
    # average the per-pixel CE loss into one value per batch element
    loss_seg = torch.mean(loss_seg.view(loss_seg.shape[0], -1), 1)
    loss_vertex = smooth_l1_loss(vertex_pred, vertex, vertex_weights, reduce=False)
    precision, recall = compute_precision_recall(seg_pred, mask)
    return seg_pred, vertex_pred, loss_seg, loss_vertex, precision, recall
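# Every example in this collection calls a repo-local weighted
# smooth_l1_loss(pred, target, weights, reduce=...) helper rather than
# torch.nn.functional.smooth_l1_loss. A minimal sketch of such a helper,
# given here as an assumption (the source repo's normalization may differ):
import torch

def smooth_l1_loss(pred, target, weights, sigma=1.0, reduce=True):
    """Smooth L1 on (pred - target), masked by per-location weights.

    With reduce=False this returns one loss value per batch element,
    matching how the examples below call torch.mean(...) afterwards.
    """
    sigma2 = sigma ** 2
    diff = weights * (pred - target)
    abs_diff = torch.abs(diff)
    quadratic = (abs_diff < 1.0 / sigma2).detach().float()
    loss = (diff ** 2) * (sigma2 / 2.0) * quadratic \
        + (abs_diff - 0.5 / sigma2) * (1.0 - quadratic)
    b = pred.shape[0]
    # normalize by the number of supervised locations in each sample
    loss = torch.sum(loss.view(b, -1), 1) / (torch.sum(weights.view(b, -1), 1) + 1e-3)
    return torch.mean(loss) if reduce else loss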
def validate(network, valLoader):
	lossSegTotal = 0
	lossVertexTotal = 0
	lossTotal = 0
	network.eval()
	for idx, data in enumerate(valLoader):

		with torch.no_grad():

			# Extract data and forward propagate
			image, maskGT, vertexGT, vertexWeightsGT = [d.cuda() for d in data]
			segPred, vertexPred = network(image)

			# Compute loss
			criterion = CrossEntropyLoss(reduction='none')  # imported from torch.nn
			lossSeg = criterion(segPred, maskGT)
			lossSeg = torch.mean(lossSeg.view(lossSeg.shape[0], -1), 1)
			lossVertex = smooth_l1_loss(vertexPred, vertexGT, vertexWeightsGT, reduce=False)
			precision, recall = compute_precision_recall(segPred, maskGT)
			lossSeg = torch.mean(lossSeg)  # Mean over batch
			lossVertex = torch.mean(lossVertex)  # Mean over batch
			loss = (1 - lossRatio)*lossSeg + lossRatio*lossVertex  # lossRatio is a module-level weight

			# Update running average losses
			lossSegTotal = (lossSegTotal*idx + lossSeg.item())/(idx+1)
			lossVertexTotal = (lossVertexTotal*idx + lossVertex.item())/(idx+1)
			lossTotal = (lossTotal*idx + loss.item())/(idx+1)

	return lossTotal, lossVertexTotal, lossSegTotal
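# compute_precision_recall is another repo-local helper. A plausible minimal
# sketch (an assumption, not the repo's exact implementation): foreground
# precision and recall computed from the argmax of the segmentation logits.
def compute_precision_recall(seg_pred, mask, eps=1e-6):
    pred = torch.argmax(seg_pred, dim=1)          # (b, h, w) predicted labels
    tp = ((pred > 0) & (mask > 0)).float().sum()  # true foreground positives
    precision = tp / ((pred > 0).float().sum() + eps)
    recall = tp / ((mask > 0).float().sum() + eps)
    return precision, recall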
Example #3
def train(net, optimizer, dataloader, device, epoch):
    net.train()
    size = len(dataloader)
    for idx, data in enumerate(dataloader):
        im_data, mask, vertex_targets, vertex_weights = [
            d.to(device) for d in data
        ]

        seg_score, seg_pred, vertex_pred = net(im_data)
        loss_cls = F.cross_entropy(seg_score, mask)
        loss_vertex = smooth_l1_loss(vertex_pred, vertex_targets,
                                     vertex_weights)
        loss = loss_cls + loss_vertex
        loss_rec.update(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # compute step unconditionally so the rec_interval branch below
        # never logs with a stale value
        step = epoch * size + idx
        if idx % print_interval == 0:
            recorder.rec_loss(loss_rec.avg, step)
            loss_rec.reset()

        if idx % rec_interval == 0:
            batch_size = im_data.shape[0]
            nrow = min(batch_size, 5)
            recorder.rec_segmentation(seg_pred,
                                      num_classes=2,
                                      nrow=nrow,
                                      step=step)
            recorder.rec_vertex(vertex_pred, vertex_weights, nrow=4, step=step)
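# loss_rec above follows the common AverageMeter pattern, and recorder,
# print_interval and rec_interval live at module level in the source repo.
# A minimal AverageMeter sketch (assumed) that supports this usage:
class AverageMeter:
    """Tracks a running average of scalar values between reset() calls."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        value = float(value)  # accepts Python floats or 0-d tensors
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count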
Example #4
    def forward(self, image, mask, vertex, vertex_weights, vertex_init_pert,
                vertex_init):

        vertex_pred, x2s, x4s, x8s, xfc = self.estNet(vertex_weights *
                                                      vertex_init_pert)
        seg_pred, q_pred = self.imNet(image, x2s, x4s, x8s, xfc)

        loss_q = smooth_l1_loss(q_pred, (vertex_init - vertex),
                                vertex_weights,
                                reduce=False)
        loss_vertex = smooth_l1_loss(vertex_pred,
                                     vertex_init,
                                     vertex_weights,
                                     reduce=False)
        loss = (10 * loss_vertex) + loss_q  # vertex regression weighted 10x relative to loss_q

        precision, recall = compute_precision_recall(seg_pred, mask)
        return seg_pred, vertex_pred, q_pred, loss, precision, recall
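# Hypothetical training step for the refinement module above (model, optimizer
# and the input tensors are placeholder names): the per-sample loss returned
# by forward() is averaged into a scalar before backpropagation.
seg_pred, vertex_pred, q_pred, loss, precision, recall = model(
    image, mask, vertex, vertex_weights, vertex_init_pert, vertex_init)
optimizer.zero_grad()
torch.mean(loss).backward()
optimizer.step()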
Example #5
    def compute_loss(self, seg_pred, vertex_pred, mask, vertex, vertex_weights,
                     vertex_loss_ratio):
        criterion = nn.CrossEntropyLoss(reduction='none')  # reduce=False is deprecated
        loss_seg = criterion(seg_pred, mask)
        loss_seg = torch.mean(loss_seg.view(loss_seg.shape[0], -1), 1)
        loss_vertex = smooth_l1_loss(vertex_pred,
                                     vertex,
                                     vertex_weights,
                                     reduce=False)

        return loss_seg + loss_vertex * vertex_loss_ratio, loss_seg, loss_vertex
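# Hypothetical usage of compute_loss (names are placeholders): the combined
# per-sample loss is reduced to a scalar before backward().
total, loss_seg, loss_vertex = model.compute_loss(
    seg_pred, vertex_pred, mask, vertex, vertex_weights, vertex_loss_ratio=1.0)
torch.mean(total).backward()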
Example #6
    def forward(self, image, mask, vertex, vertex_weights, hcoords=None):
        seg_pred, vertex_pred = self.net(image)
        loss_seg = self.criterion(seg_pred, mask)
        loss_seg = torch.mean(loss_seg.view(loss_seg.shape[0],-1),1)

        mask_pred = torch.argmax(seg_pred, dim=1, keepdim=True).float().detach()  # inherits seg_pred's device
        loss_vertex = smooth_l1_loss(vertex_pred, vertex, vertex_weights, reduce=False)
        loss_p2l, loss_voting = center_voting_loss_v1(vertex_pred, vertex_weights, mask_pred, hcoords)

        precision, recall = compute_precision_recall(seg_pred, mask)
        return seg_pred, vertex_pred, loss_seg, loss_vertex, loss_p2l, loss_voting, precision, recall
Example #7
    def forward(self, image, mask, vertex, vertex_weights):
        # Feed in an RGB image; the net outputs segmentation maps plus kp*2 vertex-field maps
        seg_pred, vertex_pred = self.net(image)
        # Semantic segmentation: each foreground label plus the background map
        loss_seg = self.criterion(seg_pred, mask)

        loss_seg = torch.mean(loss_seg.view(loss_seg.shape[0], -1), 1)
        # All inputs are batches of n images
        loss_vertex = smooth_l1_loss(vertex_pred,
                                     vertex,
                                     vertex_weights,
                                     reduce=False)

        precision, recall = compute_precision_recall(seg_pred, mask)

        return seg_pred, vertex_pred, loss_seg, loss_vertex, precision, recall
Example #8
class NetWrapper(nn.Module):
    def __init__(self, net):
        super(NetWrapper, self).__init__()
        self.net = net
        self.criterion = nn.CrossEntropyLoss(reduction='none')  # reduce=False is deprecated

    def forward(self, image, mask, vertex, vertex_weights):
        seg_pred, vertex_pred = self.net(image, mode="mapped")
        loss_seg = self.criterion(seg_pred, mask)
        loss_seg = torch.mean(loss_seg.view(loss_seg.shape[0], -1), 1)
        loss_vertex = smooth_l1_loss(vertex_pred, vertex, vertex_weights, reduce=False)
        precision, recall = compute_precision_recall(seg_pred, mask)
        return seg_pred, vertex_pred, loss_seg, loss_vertex, precision, recall


class EvalWrapper(nn.Module):
    def forward(self, seg_pred, vertex_pred, use_argmax=True):
        vertex_pred = vertex_pred.permute(0, 2, 3, 1)
        b, h, w, vn_2 = vertex_pred.shape
        # split the 2*vn channel dim into vn keypoints x 2 (one 2D vector per pixel)
        vertex_pred = vertex_pred.view(b, h, w, vn_2 // 2, 2)
        if use_argmax:
            mask = torch.argmax(seg_pred, 1)
        else:
            mask = seg_pred
        return ransac_voting_layer_v3(mask, vertex_pred, 512, inlier_thresh=0.99)
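# Hypothetical inference-time usage of the two wrappers (net and image are
# placeholders): decode 2D keypoints from the raw outputs via RANSAC voting.
evaluator = EvalWrapper().cuda()
with torch.no_grad():
    seg_pred, vertex_pred = net(image, mode="mapped")
    keypoints = evaluator(seg_pred, vertex_pred)  # (b, num_keypoints, 2)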
Example #9
def train(network, trainLoader, optimizer):

	network.train()
	lossSegTotal = 0
	lossVertexTotal = 0
	lossTotal = 0

	tTrainingLoopStart = time.time()
	for idx, data in enumerate(trainLoader):
		
		# Extract data
		tExtractDataStart = time.time()
		image, maskGT, vertexGT, vertexWeightsGT = [d.cuda() for d in data]
		tExtractDataElapsed = time.time() - tExtractDataStart


		# Forward propagate
		tForwardPropStart = time.time()
		segPred, vertexPred = network(image)
		tForwardPropElapsed = time.time() - tForwardPropStart

		# Compute loss
		tComputeLossStart = time.time()
		criterion = CrossEntropyLoss(reduction='none')  # imported from torch.nn
		lossSeg = criterion(segPred, maskGT)
		lossSeg = torch.mean(lossSeg.view(lossSeg.shape[0], -1), 1)
		lossVertex = smooth_l1_loss(vertexPred, vertexGT, vertexWeightsGT, reduce=False)
		#precision, recall = compute_precision_recall(segPred, maskGT)
		lossSeg = torch.mean(lossSeg)  # Mean over batch
		lossVertex = torch.mean(lossVertex)  # Mean over batch
		loss = (1 - lossRatio)*lossSeg + lossRatio*lossVertex  # lossRatio is a module-level weight
		tComputeLossElapsed = time.time() - tComputeLossStart

		# # test custom score
		# nBatches = len(maskGT)
		# #[print_attributes('shape', variables=x) for x in (maskGT, vertexGT, vertexWeightsGT, segPred, vertexPred)]
		
		# for iBatch in range(nBatches):
		# 	valScores.append(custom_net_score(maskGT[iBatch:iBatch+1], vertexGT[iBatch:iBatch+1], vertexWeightsGT[iBatch:iBatch+1], segPred[iBatch:iBatch+1], vertexPred[iBatch]))

		# Update weights
		tUpdateWeightStart = time.time()
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		tUpdateWeightElapsed = time.time() - tUpdateWeightStart

		# Update running average losses
		lossSegTotal = (lossSegTotal*idx + lossSeg.item())/(idx+1)
		lossVertexTotal = (lossVertexTotal*idx + lossVertex.item())/(idx+1)
		lossTotal = (lossTotal*idx + loss.item())/(idx+1)

		# Print training loop iteration time(s)
		tTrainingLoopElapsed = time.time() - tTrainingLoopStart
		if idx % max(1, nIterations // 10) == 0:  # nIterations is defined at module level
			# print('Extracting data took ' + str(tExtractDataElapsed) + ' seconds.')
			# print('Forward propagating took ' + str(tForwardPropElapsed) + ' seconds.')
			# print('Computing loss took ' + str(tComputeLossElapsed) + ' seconds.')
			# print('Updating weights took ' + str(tUpdateWeightElapsed) + ' seconds.')
			# print('Individual steps took at total of {} seconds.'.format(tExtractDataElapsed+tForwardPropElapsed+tComputeLossElapsed+tUpdateWeightElapsed))
			print('Ended training loop iteration {}/{}, elapsed time is {} seconds.'.format(idx, nIterations, tTrainingLoopElapsed))
			if idx != 0:
				print('Expected time until end of training epoch: {} seconds'.format((nIterations/idx-1)*tTrainingLoopElapsed))

	return lossTotal, lossVertexTotal, lossSegTotal
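# Hypothetical epoch loop tying train() and validate() together; lossRatio,
# nIterations and nEpochs are assumed module-level values in the source script.
for epoch in range(nEpochs):
    trainLoss, trainLossVertex, trainLossSeg = train(network, trainLoader, optimizer)
    valLoss, valLossVertex, valLossSeg = validate(network, valLoader)
    print('Epoch {}: train loss {:.4f}, val loss {:.4f}'.format(epoch, trainLoss, valLoss))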