def run(self, images):
    """Run the detector network and return the dense heatmap.

    Args:
        images: tensor [batch(1), 1, H, W] input image batch.

    Returns:
        heatmap tensor produced from the detector head. Side effects:
        caches the raw network outputs in ``self.outs`` and a numpy copy
        of the heatmap in ``self.heatmap``.

    Raises:
        ValueError: if the detector head has an unexpected channel count
            (previously this fell through to an unbound-variable NameError).
    """
    from Train_model_heatmap import Train_model_heatmap
    from utils.var_dim import toNumpy

    # flatten_64to1 is invoked on the class itself (unbound helper).
    train_agent = Train_model_heatmap

    with torch.no_grad():
        outs = self.net(images)
    semi = outs['semi']
    self.outs = outs

    channel = semi.shape[1]
    if channel == 64:
        # 64-channel head: depth-to-space without a dustbin channel.
        heatmap = train_agent.flatten_64to1(semi, cell_size=self.cell_size)
    elif channel == 65:
        # 65-channel head: softmax + dustbin, handled by flattenDetection.
        heatmap = flattenDetection(semi, tensor=True)
    else:
        raise ValueError("unexpected detector channel count: %d" % channel)

    # keep a numpy copy for downstream consumers
    self.heatmap = toNumpy(heatmap)
    return heatmap
def process_output(self, sp_processer):
    """Post-process the cached network output into keypoints + descriptors.

    Args:
        sp_processer: processer object providing ``heatmap_to_nms``,
            ``pred_soft_argmax`` and ``batch_extract_features``.

    Returns:
        dict: ``self.output`` updated in place with
            pts: tensor [batch, N, 2] (no grad) (x, y)
            pts_offset: tensor [batch, N, 2] (grad) (x, y)
            pts_desc: tensor [batch, N, 256] (grad)
    """
    from utils.utils import flattenDetection

    result = self.output
    # dense detector map: [batch_size, 1, H, W]
    dense_heatmap = flattenDetection(result['semi'])
    # non-maximum suppression, kept as a tensor
    nms_heatmap = sp_processer.heatmap_to_nms(dense_heatmap, tensor=True)
    # sub-pixel offsets via soft-argmax around each surviving peak
    offsets = sp_processer.pred_soft_argmax(nms_heatmap, dense_heatmap)['pred']
    # sample descriptors at the refined keypoint locations and merge
    result.update(
        sp_processer.batch_extract_features(result['desc'], nms_heatmap, offsets))
    self.output = result
    return result
def run(self, inp, onlyHeatmap=False, train=True):
    """Extract keypoints and descriptors from a batch of images.

    Args:
        inp: tensor [batch, 1, H, W] float32 images in range [0, 1].
        onlyHeatmap: if True, return only the dense detector heatmap.
        train: if True, run the forward pass with gradients enabled;
            otherwise the pass is wrapped in torch.no_grad().

    Returns:
        If onlyHeatmap: heatmap tensor [batch, 1, H, W].
        Otherwise (pts, pts_desc, dense_desc, heatmap), where pts is a
        list of per-image keypoint arrays (row 0 = x, row 1 = y, as
        indexed below; presumably 3xN with a confidence row — confirm in
        getPtsFromHeatmap), pts_desc a list of per-image Nx256 descriptor
        arrays, and dense_desc the full-resolution descriptor tensor.
        When self.subpixel is set, self.pts_subpixel replaces pts in the
        returned tuple.
    """
    # assert img.ndim == 2, 'Image must be grayscale.'
    # assert img.dtype == np.float32, 'Image must be float32.'
    # inp = torch.autograd.Variable(inp).view(1, 1, H, W)
    # if self.cuda: inp = inp.to(self.device)
    batch_size, H, W = inp.shape[0], inp.shape[2], inp.shape[3]

    if train:
        # gradients flow through this pass
        # outs = self.net.forward(inp, subpixel=self.subpixel)
        outs = self.net.forward(inp)
        # semi, coarse_desc = outs[0], outs[1]
        semi, coarse_desc = outs['semi'], outs['desc']
    else:
        # Forward pass of network without gradient tracking.
        with torch.no_grad():
            # outs = self.net.forward(inp, subpixel=self.subpixel)
            outs = self.net.forward(inp)
            # semi, coarse_desc = outs[0], outs[1]
            semi, coarse_desc = outs['semi'], outs['desc']

    # as tensor
    from utils.utils import labels2Dto3D, flattenDetection
    from utils.d2s import DepthToSpace

    # flatten detection: detector logits -> dense heatmap [batch, 1, H, W]
    heatmap = flattenDetection(semi, tensor=True)
    self.heatmap = heatmap
    # depth2space = DepthToSpace(8)
    # heatmap = depth2space(semi[:,:-1,:,:]).squeeze(0) ## need to change for batches

    if onlyHeatmap:
        return heatmap

    # extract keypoints per batch element (numpy side)
    pts = [
        self.getPtsFromHeatmap(heatmap[i, :, :, :].cpu().detach().numpy())
        for i in range(batch_size)
    ]
    self.pts = pts

    if self.subpixel:
        # NOTE(review): `outs` is a dict ('semi'/'desc') above, so `outs[2]`
        # will raise KeyError unless the network adds an integer key in
        # subpixel mode — confirm against self.net.forward.
        labels_res = outs[2]
        # NOTE(review): toNumpy is assumed imported at module level — verify.
        self.pts_subpixel = [
            self.subpixel_predict(toNumpy(labels_res[i, ...]), pts[i])
            for i in range(batch_size)
        ]
    '''
    pts: list [batch_size, np(N_i, 3)] -- each point (x, y, probability)
    '''
    # interpolate description
    '''
    coarse_desc: tensor (Batch_size, 256, Hc, Wc)
    dense_desc: tensor (batch_size, 256, H, W)
    '''
    # upsample the coarse descriptor grid to full image resolution
    dense_desc = nn.functional.interpolate(
        coarse_desc, scale_factor=(self.cell, self.cell), mode='bilinear')

    # norm the descriptor
    def norm_desc(desc):
        dn = torch.norm(desc, p=2, dim=1)  # Compute the norm.
        desc = desc.div(torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
        return desc

    dense_desc = norm_desc(dense_desc)

    # extract descriptors by sampling the dense map at integer coords
    dense_desc_cpu = dense_desc.cpu().detach().numpy()
    # pts[i] is indexed row-wise: row 1 is y, row 0 is x
    pts_desc = [
        dense_desc_cpu[i, :, pts[i][1, :].astype(int), pts[i][0, :].astype(int)].transpose()
        for i in range(len(pts))
    ]

    if self.subpixel:
        return self.pts_subpixel, pts_desc, dense_desc, heatmap
    return pts, pts_desc, dense_desc, heatmap
def get_heatmap(self, semi, det_loss_type="softmax"):
    """Turn the raw detector head output into a dense heatmap.

    Args:
        semi: raw detector output from the network.
        det_loss_type: "l2" selects the 64-channel reshape path
            (self.flatten_64to1); any other value uses the
            softmax-based flattenDetection.

    Returns:
        The dense detection heatmap.
    """
    if det_loss_type == "l2":
        return self.flatten_64to1(semi)
    return flattenDetection(semi)
def add2tensorboard_nms(self, img, labels_2D, semi, task="training", batch_size=1):
    """Log NMS'd detections vs. labels to tensorboard (deprecated).

    For each batch element: runs fast NMS on the detector heatmap, writes
    the first few label/detection overlays as images, and logs
    batch-averaged precision/recall scalars. An optional box-NMS variant
    is kept behind the ``boxNms`` flag (disabled).

    :param img: image batch tensor.
    :param labels_2D: ground-truth keypoint maps, one per batch element.
    :param semi: raw detector output from the network.
    :param task: tag prefix for tensorboard entries.
    :param batch_size: number of batch elements to process.
    :return: None (side effects on self.writer and stdout only).
    """
    from utils.utils import getPtsFromHeatmap
    from utils.utils import box_nms

    boxNms = False  # box-NMS comparison branch, disabled by default
    n_iter = self.n_iter
    nms_dist = self.config["model"]["nms"]
    conf_thresh = self.config["model"]["detection_threshold"]
    precision_recall_list = []
    precision_recall_boxnms_list = []
    for idx in range(batch_size):
        # dense heatmap for one sample, detached from the graph
        semi_flat_tensor = flattenDetection(semi[idx, :, :, :]).detach()
        semi_flat = toNumpy(semi_flat_tensor)
        semi_thd = np.squeeze(semi_flat, 0)
        pts_nms = getPtsFromHeatmap(semi_thd, conf_thresh, nms_dist)
        # binary map with 1s at surviving keypoints (row 1 = y, row 0 = x).
        # np.int was removed in NumPy 1.24 -- use the builtin int instead.
        semi_thd_nms_sample = np.zeros_like(semi_thd)
        semi_thd_nms_sample[pts_nms[1, :].astype(int),
                            pts_nms[0, :].astype(int)] = 1

        label_sample = torch.squeeze(labels_2D[idx, :, :, :])
        # pts_nms = getPtsFromHeatmap(label_sample.numpy(), conf_thresh, nms_dist)
        # label_sample_rms_sample = np.zeros_like(label_sample.numpy())
        label_sample_nms_sample = label_sample

        # only write overlay images for the first few samples
        if idx < 5:
            result_overlap = img_overlap(
                np.expand_dims(label_sample_nms_sample, 0),
                np.expand_dims(semi_thd_nms_sample, 0),
                toNumpy(img[idx, :, :, :]),
            )
            self.writer.add_image(
                task + "-detector_output_thd_overlay-NMS" + "/%d" % idx,
                result_overlap,
                n_iter,
            )
        assert semi_thd_nms_sample.shape == label_sample_nms_sample.size()
        precision_recall = precisionRecall_torch(
            torch.from_numpy(semi_thd_nms_sample), label_sample_nms_sample)
        precision_recall_list.append(precision_recall)

        if boxNms:
            semi_flat_tensor_nms = box_nms(
                semi_flat_tensor.squeeze(), nms_dist, min_prob=conf_thresh).cpu()
            semi_flat_tensor_nms = (semi_flat_tensor_nms >= conf_thresh).float()
            if idx < 5:
                result_overlap = img_overlap(
                    np.expand_dims(label_sample_nms_sample, 0),
                    semi_flat_tensor_nms.numpy()[np.newaxis, :, :],
                    toNumpy(img[idx, :, :, :]),
                )
                self.writer.add_image(
                    task + "-detector_output_thd_overlay-boxNMS" + "/%d" % idx,
                    result_overlap,
                    n_iter,
                )
            precision_recall_boxnms = precisionRecall_torch(
                semi_flat_tensor_nms, label_sample_nms_sample)
            precision_recall_boxnms_list.append(precision_recall_boxnms)

    # batch-averaged metrics for the fast-NMS path
    precision = np.mean([
        precision_recall["precision"]
        for precision_recall in precision_recall_list
    ])
    recall = np.mean([
        precision_recall["recall"]
        for precision_recall in precision_recall_list
    ])
    self.writer.add_scalar(task + "-precision_nms", precision, n_iter)
    self.writer.add_scalar(task + "-recall_nms", recall, n_iter)
    print("-- [%s-%d-fast NMS] precision: %.4f, recall: %.4f" %
          (task, n_iter, precision, recall))

    if boxNms:
        precision = np.mean([
            precision_recall["precision"]
            for precision_recall in precision_recall_boxnms_list
        ])
        recall = np.mean([
            precision_recall["recall"]
            for precision_recall in precision_recall_boxnms_list
        ])
        self.writer.add_scalar(task + "-precision_boxnms", precision, n_iter)
        self.writer.add_scalar(task + "-recall_boxnms", recall, n_iter)
        print("-- [%s-%d-boxNMS] precision: %.4f, recall: %.4f" %
              (task, n_iter, precision, recall))
def addImg2tensorboard(
    self,
    img,
    labels_2D,
    semi,
    img_warp=None,
    labels_warp_2D=None,
    mask_warp_2D=None,
    semi_warp=None,
    mask_3D_flattened=None,
    task="training",
):
    """Write detection/label overlay images to tensorboard (deprecated).

    Thresholds the detector output for the first batch element of both
    the original and warped images, overlays them with their labels,
    writes the overlays (and saves two PNG snapshots), and logs the
    warped validity masks for the first few batch elements.

    :param img: image batch tensor.
    :param labels_2D: ground-truth keypoint maps for img.
    :param semi: raw detector output for img.
    :param img_warp: warped image batch.
    :param labels_warp_2D: labels for the warped images.
    :param mask_warp_2D: 2D validity mask for the warped images.
    :param semi_warp: raw detector output for the warped images.
    :param mask_3D_flattened: flattened 3D mask logged per sample.
    :param task: tag prefix for tensorboard entries.
    :return: None (side effects: tensorboard writes and PNG files).
    """
    # print("add images to tensorboard")
    n_iter = self.n_iter
    # only the first batch element is visualized for the overlays
    semi_flat = flattenDetection(semi[0, :, :, :])
    semi_warp_flat = flattenDetection(semi_warp[0, :, :, :])

    thd = self.config["model"]["detection_threshold"]
    semi_thd = thd_img(semi_flat, thd=thd)
    semi_warp_thd = thd_img(semi_warp_flat, thd=thd)

    # overlay of labels vs. thresholded detections on the original image
    result_overlap = img_overlap(toNumpy(labels_2D[0, :, :, :]),
                                 toNumpy(semi_thd),
                                 toNumpy(img[0, :, :, :]))
    self.writer.add_image(task + "-detector_output_thd_overlay",
                          result_overlap, n_iter)
    saveImg(
        result_overlap.transpose([1, 2, 0])[..., [2, 1, 0]] * 255,
        "test_0.png")  # rgb to bgr * 255

    # same overlay on the warped image
    result_overlap = img_overlap(
        toNumpy(labels_warp_2D[0, :, :, :]),
        toNumpy(semi_warp_thd),
        toNumpy(img_warp[0, :, :, :]),
    )
    self.writer.add_image(task + "-warp_detector_output_thd_overlay",
                          result_overlap, n_iter)
    saveImg(
        result_overlap.transpose([1, 2, 0])[..., [2, 1, 0]] * 255,
        "test_1.png")  # rgb to bgr * 255

    # visualize the invalid region of the warped mask over the warped image
    mask_overlap = img_overlap(
        toNumpy(1 - mask_warp_2D[0, :, :, :]) / 2,
        np.zeros_like(toNumpy(img_warp[0, :, :, :])),
        toNumpy(img_warp[0, :, :, :]),
    )
    # writer.add_image(task + '_mask_valid_first_layer', mask_warp[0, :, :, :], n_iter)
    # writer.add_image(task + '_mask_valid_last_layer', mask_warp[-1, :, :, :], n_iter)
    ##### print to check
    # print("mask_2D shape: ", mask_warp_2D.shape)
    # print("mask_3D_flattened shape: ", mask_3D_flattened.shape)
    # NOTE(review): the same tag is reused for every i at the same n_iter,
    # so later samples appear to overwrite earlier ones -- confirm intent.
    for i in range(self.batch_size):
        if i < 5:
            self.writer.add_image(task + "-mask_warp_origin",
                                  mask_warp_2D[i, :, :, :], n_iter)
            self.writer.add_image(task + "-mask_warp_3D_flattened",
                                  mask_3D_flattened[i, :, :], n_iter)
    # self.writer.add_image(task + '-mask_warp_origin-1', mask_warp_2D[1, :, :, :], n_iter)
    # self.writer.add_image(task + '-mask_warp_3D_flattened-1', mask_3D_flattened[1, :, :], n_iter)
    self.writer.add_image(task + "-mask_warp_overlay", mask_overlap, n_iter)
def train_val_sample(self, sample, n_iter=0, train=False):
    """Run one training/validation step on a sample (deprecated).

    Computes detector losses on the original and warped images, the dense
    descriptor loss, and (when self.subpixel is set) a patch-based
    sub-pixel regression loss; backpropagates and steps the optimizer in
    train mode, and logs losses/images to tensorboard.

    :param sample: dict with keys "image", "labels_2D", "valid_mask",
        "warped_img", "warped_labels", "warped_valid_mask",
        "homographies", "inv_homographies" (and "warped_res" when
        self.subpixel is set).
    :param n_iter: current iteration, used for tensorboard tagging.
    :param train: True for a training step (backward + optimizer step),
        False for validation (no_grad forward only).
    :return: scalar loss value (float).
    """
    task = "train" if train else "val"
    tb_interval = self.config["tensorboard_interval"]
    losses = {}

    ## get the inputs
    img, labels_2D, mask_2D = (
        sample["image"],
        sample["labels_2D"],
        sample["valid_mask"],
    )
    # variables
    batch_size, H, W = img.shape[0], img.shape[2], img.shape[3]
    self.batch_size = batch_size
    Hc = H // self.cell_size
    Wc = W // self.cell_size

    # warped images
    img_warp, labels_warp_2D, mask_warp_2D = (
        sample["warped_img"],
        sample["warped_labels"],
        sample["warped_valid_mask"],
    )

    # homographies
    mat_H, mat_H_inv = sample["homographies"], sample["inv_homographies"]

    # zero the parameter gradients
    self.optimizer.zero_grad()

    # forward + backward + optimize
    if train:
        outs, outs_warp = (
            self.net(img.to(self.device)),
            self.net(img_warp.to(self.device), subpixel=self.subpixel),
        )
        semi, coarse_desc = outs[0], outs[1]
        semi_warp, coarse_desc_warp = outs_warp[0], outs_warp[1]
    else:
        with torch.no_grad():
            outs, outs_warp = (
                self.net(img.to(self.device)),
                self.net(img_warp.to(self.device), subpixel=self.subpixel),
            )
            semi, coarse_desc = outs[0], outs[1]
            semi_warp, coarse_desc_warp = outs_warp[0], outs_warp[1]
        pass

    # detector loss
    ## get labels, masks, loss for detection on the original image
    labels3D_in_loss = self.getLabels(labels_2D, self.cell_size,
                                      device=self.device)
    mask_3D_flattened = self.getMasks(mask_2D, self.cell_size,
                                      device=self.device)
    loss_det = self.get_loss(semi, labels3D_in_loss, mask_3D_flattened,
                             device=self.device)

    ## warping: same loss on the warped image.
    # NOTE(review): labels3D_in_loss / mask_3D_flattened are overwritten
    # here; the warped mask is what later reaches the descriptor loss and
    # addImg2tensorboard.
    labels3D_in_loss = self.getLabels(labels_warp_2D, self.cell_size,
                                      device=self.device)
    mask_3D_flattened = self.getMasks(mask_warp_2D, self.cell_size,
                                      device=self.device)
    loss_det_warp = self.get_loss(semi_warp, labels3D_in_loss,
                                  mask_3D_flattened, device=self.device)

    mask_desc = mask_3D_flattened.unsqueeze(1)

    # descriptor loss (dense)
    loss_desc, mask, positive_dist, negative_dist = self.descriptor_loss(
        coarse_desc,
        coarse_desc_warp,
        mat_H,
        mask_valid=mask_desc,
        device=self.device,
        **self.desc_params)

    loss = (loss_det + loss_det_warp +
            self.config["model"]["lambda_loss"] * loss_desc)

    if self.subpixel:
        # dense heatmap of the warped detector output: [batch, 1, H, W]
        dense_map = flattenDetection(semi_warp)
        # concat image and dense heatmap: [batch, n, H, W]
        # NOTE(review): concat_features is built but never used below —
        # looks like a leftover from the commented subpixNet path.
        concat_features = torch.cat((img_warp.to(self.device), dense_map),
                                    dim=1)
        # prediction from the network's subpixel head
        pred_heatmap = outs_warp[2]  # tensor [batch, 1, H, W]

        labels_warped_res = sample["warped_res"]
        subpix_loss = self.subpixel_loss_func(
            labels_warp_2D.to(self.device),
            labels_warped_res.to(self.device),
            pred_heatmap.to(self.device),
            patch_size=11,
        )

        # extract the patches around labelled points
        label_idx = labels_2D[...].nonzero()
        from utils.losses import extract_patches
        patch_size = 32
        patches = extract_patches(
            label_idx.to(self.device),
            img_warp.to(self.device),
            patch_size=patch_size,
        )  # tensor [N, patch_size, patch_size]
        print("patches: ", patches.shape)

        # gather the residual targets at the labelled point indices
        def label_to_points(labels_res, points):
            labels_res = labels_res.transpose(1, 2).transpose(2, 3).unsqueeze(1)
            points_res = labels_res[points[:, 0], points[:, 1],
                                    points[:, 2], points[:, 3], :]  # tensor [N, 2]
            return points_res

        points_res = label_to_points(labels_warped_res, label_idx)

        num_patches_max = 500  # cap the number of patches per step

        # feed patches into the subpixel regression network
        pred_res = self.subnet(patches[:num_patches_max, ...].to(
            self.device))  # tensor [1, N, 2]

        # L2 residual loss between target and predicted offsets
        def get_loss(points_res, pred_res):
            loss = points_res - pred_res
            loss = torch.norm(loss, p=2, dim=-1).mean()
            return loss

        # NOTE(review): this REPLACES the det+desc loss computed above, so
        # in subpixel mode only the patch regression loss is optimized —
        # confirm this is intentional.
        loss = get_loss(points_res[:num_patches_max, ...].to(self.device),
                        pred_res)
        losses.update({"subpix_loss": subpix_loss})

    self.loss = loss
    # NOTE(review): "loss_det"/"loss_det_warp" appear twice in this dict
    # literal; the duplicates are harmless but should be removed.
    losses.update({
        "loss": loss,
        "loss_det": loss_det,
        "loss_det_warp": loss_det_warp,
        "loss_det": loss_det,
        "loss_det_warp": loss_det_warp,
        "positive_dist": positive_dist,
        "negative_dist": negative_dist,
    })

    if train:
        loss.backward()
        self.optimizer.step()

    self.addLosses2tensorboard(losses, task)

    if n_iter % tb_interval == 0 or task == "val":
        logging.info("current iteration: %d, tensorboard_interval: %d",
                     n_iter, tb_interval)
        self.addImg2tensorboard(
            img,
            labels_2D,
            semi,
            img_warp,
            labels_warp_2D,
            mask_warp_2D,
            semi_warp,
            mask_3D_flattened=mask_3D_flattened,
            task=task,
        )
        if self.subpixel:
            self.add_single_image_to_tb(task, pred_heatmap, n_iter,
                                        name="subpixel_heatmap")

    self.printLosses(losses, task)

    # if n_iter % tb_interval == 0 or task == 'val':
    # print ("add nms")
    self.add2tensorboard_nms(img, labels_2D, semi, task=task,
                             batch_size=batch_size)
    return loss.item()