def gtscore(self, right_score, homolr):
    """Derive image-1 ground-truth scores from the right image's score map.

    The right score map is border-filtered and warped with ``homolr``. The
    warped map is then passed through ``self.det.process`` (NMS, top-k,
    Gaussian smoothing).

    :param right_score: score map of the right image (image 2)
    :param homolr: homography handed to ``warp`` to move image-2 data into
        image 1's frame
    :return: (im1gt_score, topk_mask, topk_value, im1visible_mask).
        im1visible_mask is an all-ones map warped with the same homography.
        Presumably it marks the part of image 1 covered by image 2; confirm
        against ``warp``.
    """
    border_filtered = filter_border(right_score)
    warped_score = warp(border_filtered, homolr)

    # An all-ones map sent through the same warp yields the visibility mask.
    coverage_seed = border_filtered.new_full(
        border_filtered.size(), fill_value=1, requires_grad=True
    )
    visible_mask = warp(coverage_seed, homolr)

    gt_score, mask, values = self.det.process(warped_score)
    return gt_score, mask, values, visible_mask
def process(self, im1w_score):
    """Apply the nms (n) -> topk (t) -> gaussian kernel (g) pipeline.

    :param im1w_score: warped score map, channels-last (B, H, W, 1). The
        permutes around ``conv2d`` and the (1, 1, k, k) kernel fix this layout.
    :return: (processed score map (B, H, W, 1) clamped to [0, 1],
              topk mask,
              topk value = the NMS-suppressed scores before top-k masking)
    """
    im1w_score = filter_border(im1w_score)

    # Non-maximum suppression; threshold and window size come from the config.
    nms_mask = nms(im1w_score, thresh=self.NMS_THRESH, ksize=self.NMS_KSIZE)
    im1w_score = im1w_score * nms_mask
    # Returned as-is: scores after NMS but before top-k / smoothing.
    topk_value = im1w_score

    # Keep only the TOPK strongest responses.
    topk_mask = topk_map(im1w_score, self.TOPK)
    im1w_score = topk_mask.to(torch.float) * im1w_score

    # Gaussian point-spread function shaped (1, 1, k, k) for conv2d. Cast to
    # the score map's dtype as well as its device; a device-only move fails
    # in conv2d whenever the filter's dtype differs from the score map's.
    psf = get_gauss_filter_weight(self.GAUSSIAN_KSIZE, self.GAUSSIAN_SIGMA)[
        None, None, :, :
    ].to(device=im1w_score.device, dtype=im1w_score.dtype)
    im1w_score = F.conv2d(
        input=im1w_score.permute(0, 3, 1, 2),  # (B, H, W, 1) -> (B, 1, H, W)
        weight=psf,
        stride=1,
        padding=self.GAUSSIAN_KSIZE // 2,
    ).permute(0, 2, 3, 1)  # back to (B, H, W, 1)

    # torch.clamp keeps every value in [0, 1]. This is a safeguard: values
    # above 1 are not expected to occur when the inputs are well-formed.
    im1w_score = im1w_score.clamp(min=0.0, max=1.0)
    return im1w_score, topk_mask, topk_value
def process(self, im1w_score):
    """Run non-maximum suppression, top-k selection and Gaussian blurring.

    :param im1w_score: warped score map in (B, H, W, 1) layout
    :return: (blurred score map clamped to [0, 1], top-k mask,
              NMS-suppressed scores before top-k masking)
    """
    cleaned = filter_border(im1w_score)

    # NMS = non-maximum suppression.
    suppressed = cleaned * nms(cleaned, thresh=self.NMS_THRESH, ksize=self.NMS_KSIZE)

    # Top-k selection; the mask is cast to float before it scales the scores.
    keep = topk_map(suppressed, self.TOPK)
    peaks = keep.to(torch.float) * suppressed

    # Gaussian point-spread function as a (1, 1, k, k) conv weight. new_tensor
    # gives it the dtype and device of the score map.
    kernel = get_gauss_filter_weight(self.GAUSSIAN_KSIZE, self.GAUSSIAN_SIGMA)
    psf = peaks.new_tensor(kernel[None, None, :, :])

    # permute reorders the dimensions: conv2d wants (B, 1, H, W); the result
    # is permuted back to (B, H, W, 1).
    channels_first = peaks.permute(0, 3, 1, 2)
    blurred = F.conv2d(channels_first, psf, stride=1, padding=self.GAUSSIAN_KSIZE // 2)
    blurred = blurred.permute(0, 2, 3, 1)

    # torch.clamp pins values into [min, max]. Anything <= min becomes min,
    # anything > max becomes max, and values in between are kept unchanged.
    # It is a safeguard so no score exceeds 1, which should not happen with
    # correct inputs.
    return blurred.clamp(min=0.0, max=1.0), keep, suppressed
def train():
    """Train the detector and descriptor for one pass over ``train_data``.

    For every batch, ``des`` is optimised for ``cfg.TRAIN.DES`` steps, then
    ``det`` for ``cfg.TRAIN.DET`` steps. Every ``cfg.TRAIN.LOG_INTERVAL``
    batches, losses and accuracy are computed without gradients:
      - on the first training sample, logged to ``train_writer``;
      - on the first validation sample, logged to ``test_writer``.
    A progress line is then printed.

    Relies on module-level state defined elsewhere: det, des, det_optim,
    des_optim, cfg, device, mgpu, epoch, train_data, val_data, train_writer
    and test_writer.
    """
    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_data, 1):
        det.train()
        des.train()
        batch = parse_batch(sample_batched, device)
        # detect_anomaly makes autograd report the forward op behind NaN/Inf
        # gradients. It is a debugging aid and slows training down.
        with autograd.detect_anomaly():
            # ---------------- descriptor steps ----------------
            for _ in range(cfg.TRAIN.DES):
                des.zero_grad()
                des_optim.zero_grad()
                (im1_data, im1_info, homo12, im2_data, im2_info, homo21,
                 im1_raw, im2_raw) = batch
                # NOTE(review): det.scale_list / det.process are read directly
                # even when mgpu is set (criterion uses det.module); confirm
                # this works with the multi-GPU wrapper.
                score_maps, orint_maps = det(im1_data)
                im1_rawsc, im1_scale, im1_orin = handle_det_out(
                    score_maps, orint_maps, det.scale_list,
                    det.score_com_strength, det.scale_com_strength)
                score_maps, orint_maps = det(im2_data)
                im2_rawsc, im2_scale, im2_orin = handle_det_out(
                    score_maps, orint_maps, det.scale_list,
                    det.score_com_strength, det.scale_com_strength)

                im1_gtscale, im1_gtorin = gt_scale_orin(
                    im2_scale, im2_orin, homo12, homo21)
                im2_gtscale, im2_gtorin = gt_scale_orin(
                    im1_scale, im1_orin, homo21, homo12)

                im2_score = filter_border(im2_rawsc)
                im1w_score = warp(im2_score, homo12)
                im1_gtsc, im1_topkmask, im1_topkvalue = det.process(im1w_score)

                im1_score = filter_border(im1_rawsc)
                im2w_score = warp(im1_score, homo21)
                im2_gtsc, im2_topkmask, im2_topkvalue = det.process(im2w_score)

                im1_ppair, im1_limc, im1_rimcw = pair(
                    im1_topkmask, im1_topkvalue, im1_scale, im1_orin,
                    im1_info, im1_raw, homo12, im2_gtscale, im2_gtorin,
                    im2_info, im2_raw, cfg.PATCH.SIZE,
                )
                im2_ppair, im2_limc, im2_rimcw = pair(
                    im2_topkmask, im2_topkvalue, im2_scale, im2_orin,
                    im2_info, im2_raw, homo21, im1_gtscale, im1_gtorin,
                    im1_info, im1_raw, cfg.PATCH.SIZE,
                )

                # each is (N, 32, 32)
                im1_lpatch, im1_rpatch = im1_ppair.chunk(chunks=2, dim=1)
                im2_lpatch, im2_rpatch = im2_ppair.chunk(chunks=2, dim=1)

                im1_lpatch = des.input_norm(im1_lpatch)
                im2_lpatch = des.input_norm(im2_lpatch)
                im1_rpatch = des.input_norm(im1_rpatch)
                im2_rpatch = des.input_norm(im2_rpatch)

                im1_lpdes, im1_rpdes = des(im1_lpatch), des(im1_rpatch)
                im2_lpdes, im2_rpdes = des(im2_lpatch), des(im2_rpatch)

                endpoint = {
                    "im1_limc": im1_limc,
                    "im1_rimcw": im1_rimcw,
                    "im2_limc": im2_limc,
                    "im2_rimcw": im2_rimcw,
                    "im1_lpdes": im1_lpdes,
                    "im1_rpdes": im1_rpdes,
                    "im2_lpdes": im2_lpdes,
                    "im2_rpdes": im2_rpdes,
                }

                desloss = (des.module.criterion(endpoint) if mgpu
                           else des.criterion(endpoint))
                desloss.backward()
                des_optim.step()

            # ---------------- detector steps ----------------
            for _ in range(cfg.TRAIN.DET):
                det.zero_grad()
                det_optim.zero_grad()
                (im1_data, im1_info, homo12, im2_data, im2_info, homo21,
                 im1_raw, im2_raw) = batch
                score_maps, orint_maps = det(im1_data)
                im1_rawsc, im1_scale, im1_orin = handle_det_out(
                    score_maps, orint_maps, det.scale_list,
                    det.score_com_strength, det.scale_com_strength)
                score_maps, orint_maps = det(im2_data)
                im2_rawsc, im2_scale, im2_orin = handle_det_out(
                    score_maps, orint_maps, det.scale_list,
                    det.score_com_strength, det.scale_com_strength)

                im2_score = filter_border(im2_rawsc)
                im1w_score = warp(im2_score, homo12)
                # All-ones map warped the same way -> visibility mask.
                im1_visiblemask = warp(
                    im2_score.new_full(im2_score.size(), fill_value=1,
                                       requires_grad=True),
                    homo12,
                )
                im1_gtsc, im1_topkmask, im1_topkvalue = det.process(im1w_score)

                im1_score = filter_border(im1_rawsc)
                im2w_score = warp(im1_score, homo21)
                im2_visiblemask = warp(
                    im2_score.new_full(im1_score.size(), fill_value=1,
                                       requires_grad=True),
                    homo21,
                )
                im2_gtsc, im2_topkmask, im2_topkvalue = det.process(im2w_score)

                # Predicted score maps that the criterion compares with *_gtsc.
                im1_score = det.process(im1_rawsc)[0]
                im2_score = det.process(im2_rawsc)[0]

                im1_predpair, _, _ = pair(
                    im1_topkmask, im1_topkvalue, im1_scale, im1_orin,
                    im1_info, im1_raw, homo12, im2_scale, im2_orin,
                    im2_info, im2_raw, cfg.PATCH.SIZE,
                )
                im2_predpair, _, _ = pair(
                    im2_topkmask, im2_topkvalue, im2_scale, im2_orin,
                    im2_info, im2_raw, homo21, im1_scale, im1_orin,
                    im1_info, im1_raw, cfg.PATCH.SIZE,
                )

                # each is (N, 32, 32)
                im1_lpredpatch, im1_rpredpatch = im1_predpair.chunk(chunks=2, dim=1)
                im2_lpredpatch, im2_rpredpatch = im2_predpair.chunk(chunks=2, dim=1)

                im1_lpredpatch = des.input_norm(im1_lpredpatch)
                im2_lpredpatch = des.input_norm(im2_lpredpatch)
                im1_rpredpatch = des.input_norm(im1_rpredpatch)
                im2_rpredpatch = des.input_norm(im2_rpredpatch)

                im1_lpreddes, im1_rpreddes = des(im1_lpredpatch), des(im1_rpredpatch)
                im2_lpreddes, im2_rpreddes = des(im2_lpredpatch), des(im2_rpredpatch)

                endpoint = {
                    "im1_score": im1_score,
                    "im1_gtsc": im1_gtsc,
                    "im1_visible": im1_visiblemask,
                    "im2_score": im2_score,
                    "im2_gtsc": im2_gtsc,
                    "im2_visible": im2_visiblemask,
                    "im1_lpreddes": im1_lpreddes,
                    "im1_rpreddes": im1_rpreddes,
                    "im2_lpreddes": im2_lpreddes,
                    "im2_rpreddes": im2_rpreddes,
                }

                detloss = (det.module.criterion(endpoint) if mgpu
                           else det.criterion(endpoint))
                detloss.backward()
                det_optim.step()

        Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
        Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

        # ---------------- logging ----------------
        if i_batch % cfg.TRAIN.LOG_INTERVAL == 0 and i_batch > 0:
            elapsed = time.time() - start_time
            det.eval()
            des.eval()
            with torch.no_grad():
                # Training log: first sample of the training set.
                parsed_trainbatch = parse_unsqueeze(train_data.dataset[0], device)
                endpoint = get_all_endpoints(det, des, parsed_trainbatch)
                detloss = (det.module.criterion(endpoint) if mgpu
                           else det.criterion(endpoint))
                desloss = (des.module.criterion(endpoint) if mgpu
                           else des.criterion(endpoint))
                PLTS = {
                    "pair_loss": detloss,
                    "hard_loss": desloss,
                    "Accuracy": getAC(endpoint["im1_lpdes"], endpoint["im1_rpdes"]),
                    "det_lr": det_optim.param_groups[0]["lr"],
                    "des_lr": des_optim.param_groups[0]["lr"],
                }
                if mgpu:
                    mgpu_merge(PLTS)
                iteration = (epoch - 1) * len(train_data) + (i_batch - 1)
                writer_log(train_writer, PLTS, iteration)

                # len(train_data) is already the number of batches per epoch
                # (the same quantity `iteration` uses above), so it must not be
                # divided by the batch size again.
                pstring = (
                    "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                    "pair {:05.03f} | des {:05.03f} |".format(
                        epoch,
                        i_batch,
                        len(train_data),
                        elapsed / cfg.TRAIN.LOG_INTERVAL,
                        PLTS["pair_loss"],
                        PLTS["hard_loss"],
                    ))

                # Validation log: first sample of the validation set. The
                # losses must be computed on the validation endpoints `ept`,
                # not on the training `endpoint`.
                parsed_valbatch = parse_unsqueeze(val_data.dataset[0], device)
                ept = get_all_endpoints(det, des, parsed_valbatch)
                detloss = (det.module.criterion(ept) if mgpu
                           else det.criterion(ept))
                desloss = (des.module.criterion(ept) if mgpu
                           else des.criterion(ept))
                val_scalars = {
                    "pair_loss": detloss,
                    "hard_loss": desloss,
                    "Accuracy": getAC(ept["im1_lpdes"], ept["im1_rpdes"]),
                }
                writer_log(test_writer, val_scalars, iteration)

            print(f"{gct()} | {pstring}")
            start_time = time.time()
def get_all_endpoints(det, des, batch):
    """Compute every endpoint needed by both the detector and descriptor losses.

    :param det: detector network (called on images; provides ``process``,
        ``scale_list``, ``score_com_strength``, ``scale_com_strength``)
    :param des: descriptor network (provides ``input_norm``)
    :param batch: 8-tuple (im1_data, im1_info, homo12, im2_data, im2_info,
        homo21, im1_raw, im2_raw)
    :return: dict with the detector-loss keys (im*_score, im*_gtsc,
        im*_visible, im*_lpreddes, im*_rpreddes) and the descriptor-loss keys
        (im*_limc, im*_rimcw, im*_lpdes, im*_rpdes) for both images.
        ``train()`` passes this dict to both ``det.criterion`` and
        ``des.criterion`` and reads ``im1_lpdes`` / ``im1_rpdes`` for accuracy,
        so all keys must be present.
    """
    im1_data, im1_info, homo12, im2_data, im2_info, homo21, im1_raw, im2_raw = batch

    # Detector outputs for both images.
    score_maps, orint_maps = det(im1_data)
    im1_rawsc, im1_scale, im1_orin = handle_det_out(
        score_maps, orint_maps, det.scale_list,
        det.score_com_strength, det.scale_com_strength)
    score_maps, orint_maps = det(im2_data)
    im2_rawsc, im2_scale, im2_orin = handle_det_out(
        score_maps, orint_maps, det.scale_list,
        det.score_com_strength, det.scale_com_strength)

    im1_gtscale, im1_gtorin = gt_scale_orin(im2_scale, im2_orin, homo12, homo21)
    im2_gtscale, im2_gtorin = gt_scale_orin(im1_scale, im1_orin, homo21, homo12)

    # Ground-truth score maps: warp the other image's scores into this frame.
    # An all-ones map warped the same way gives the visibility mask.
    im2_score = filter_border(im2_rawsc)
    im1w_score = warp(im2_score, homo12)
    im1_visiblemask = warp(
        im2_score.new_full(im2_score.size(), fill_value=1, requires_grad=True),
        homo12,
    )
    im1_gtsc, im1_topkmask, im1_topkvalue = det.process(im1w_score)

    im1_score = filter_border(im1_rawsc)
    im2w_score = warp(im1_score, homo21)
    im2_visiblemask = warp(
        im2_score.new_full(im1_score.size(), fill_value=1, requires_grad=True),
        homo21,
    )
    im2_gtsc, im2_topkmask, im2_topkvalue = det.process(im2w_score)

    # Predicted score maps.
    im1_score = det.process(im1_rawsc)[0]
    im2_score = det.process(im2_rawsc)[0]

    # Descriptor-loss patches: right side uses the ground-truth scale/orientation.
    im1_ppair, im1_limc, im1_rimcw = pair(
        im1_topkmask, im1_topkvalue, im1_scale, im1_orin, im1_info, im1_raw,
        homo12, im2_gtscale, im2_gtorin, im2_info, im2_raw, cfg.PATCH.SIZE,
    )
    im2_ppair, im2_limc, im2_rimcw = pair(
        im2_topkmask, im2_topkvalue, im2_scale, im2_orin, im2_info, im2_raw,
        homo21, im1_gtscale, im1_gtorin, im1_info, im1_raw, cfg.PATCH.SIZE,
    )

    im1_lpatch, im1_rpatch = im1_ppair.chunk(chunks=2, dim=1)  # each is (N, 32, 32)
    im2_lpatch, im2_rpatch = im2_ppair.chunk(chunks=2, dim=1)  # each is (N, 32, 32)

    im1_lpatch = des.input_norm(im1_lpatch)
    im2_lpatch = des.input_norm(im2_lpatch)
    im1_rpatch = des.input_norm(im1_rpatch)
    im2_rpatch = des.input_norm(im2_rpatch)

    im1_lpdes, im1_rpdes = des(im1_lpatch), des(im1_rpatch)
    im2_lpdes, im2_rpdes = des(im2_lpatch), des(im2_rpatch)

    # Detector-loss patches: right side uses the *predicted* scale/orientation.
    im1_predpair, _, _ = pair(
        im1_topkmask, im1_topkvalue, im1_scale, im1_orin, im1_info, im1_raw,
        homo12, im2_scale, im2_orin, im2_info, im2_raw, cfg.PATCH.SIZE,
    )
    im2_predpair, _, _ = pair(
        im2_topkmask, im2_topkvalue, im2_scale, im2_orin, im2_info, im2_raw,
        homo21, im1_scale, im1_orin, im1_info, im1_raw, cfg.PATCH.SIZE,
    )

    # each is (N, 32, 32)
    im1_lpredpatch, im1_rpredpatch = im1_predpair.chunk(chunks=2, dim=1)
    im2_lpredpatch, im2_rpredpatch = im2_predpair.chunk(chunks=2, dim=1)

    im1_lpredpatch = des.input_norm(im1_lpredpatch)
    im2_lpredpatch = des.input_norm(im2_lpredpatch)
    im1_rpredpatch = des.input_norm(im1_rpredpatch)
    im2_rpredpatch = des.input_norm(im2_rpredpatch)

    im1_lpreddes, im1_rpreddes = des(im1_lpredpatch), des(im1_rpredpatch)
    im2_lpreddes, im2_rpreddes = des(im2_lpredpatch), des(im2_rpredpatch)

    endpoint = {
        # detector-loss inputs
        "im1_score": im1_score,
        "im1_gtsc": im1_gtsc,
        "im1_visible": im1_visiblemask,
        "im2_score": im2_score,
        "im2_gtsc": im2_gtsc,
        "im2_visible": im2_visiblemask,
        "im1_lpreddes": im1_lpreddes,
        "im1_rpreddes": im1_rpreddes,
        "im2_lpreddes": im2_lpreddes,
        "im2_rpreddes": im2_rpreddes,
        # descriptor-loss inputs
        "im1_limc": im1_limc,
        "im1_rimcw": im1_rimcw,
        "im2_limc": im2_limc,
        "im2_rimcw": im2_rimcw,
        "im1_lpdes": im1_lpdes,
        "im1_rpdes": im1_rpdes,
        "im2_lpdes": im2_lpdes,
        "im2_rpdes": im2_rpdes,
    }
    return endpoint