def train():
    """Train the joint ``model`` over ``train_data`` once, logging periodically.

    For every batch the function performs ``cfg.TRAIN.DES`` descriptor updates
    (stepping ``des_optim`` on the third value returned by the criterion) and
    then ``cfg.TRAIN.DET`` detector updates (stepping ``det_optim`` on the
    second value), applies both learning-rate schedules, and every
    ``cfg.TRAIN.LOG_INTERVAL`` batches writes scalars for the first training
    sample and the first validation sample to ``train_writer`` /
    ``test_writer`` and prints a one-line summary.

    All collaborators (``model``, ``train_data``, ``val_data``, ``cfg``,
    ``device``, ``epoch``, ``mgpu``, the optimizers and the writers) are free
    names resolved from the enclosing scope; the function takes no arguments
    and returns ``None``.

    NOTE(review): a second ``def train()`` appears later in this file; if both
    live in the same scope, that later definition rebinds the name and this
    one is never called - confirm which variant is intended.
    """

    def _criterion(endpoint):
        """Evaluate the model criterion, going through ``.module`` when ``mgpu`` is set."""
        if mgpu:
            return model.module.criterion(endpoint)
        return model.criterion(endpoint)

    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_data, 1):
        model.train()
        batch = parse_batch(sample_batched, device)

        with autograd.detect_anomaly():
            # Descriptor updates: only the third criterion output is optimised.
            for _ in range(cfg.TRAIN.DES):
                model.zero_grad()
                des_optim.zero_grad()
                _, _, desloss = _criterion(model(batch))
                desloss.backward()
                des_optim.step()

            # Detector updates: only the second criterion output is optimised.
            for _ in range(cfg.TRAIN.DET):
                model.zero_grad()
                det_optim.zero_grad()
                _, detloss, _ = _criterion(model(batch))
                detloss.backward()
                det_optim.step()

        Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
        Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

        # i_batch starts at 1, so no extra "> 0" guard is needed here.
        if i_batch % cfg.TRAIN.LOG_INTERVAL != 0:
            continue

        elapsed = time.time() - start_time
        model.eval()
        with torch.no_grad():
            # --- training-sample log ---
            train_endpoint = model(parse_unsqueeze(train_data.dataset[0], device))
            train_plot, _, _ = _criterion(train_endpoint)
            train_scalars = train_plot["scalar"]
            train_scalars["Accuracy"] = getAC(
                train_endpoint["im1_lpdes"], train_endpoint["im1_rpdes"]
            )
            train_scalars["det_lr"] = det_optim.param_groups[0]["lr"]
            train_scalars["des_lr"] = des_optim.param_groups[0]["lr"]
            if mgpu:
                mgpu_merge(train_scalars)

            iteration = (epoch - 1) * len(train_data) + (i_batch - 1)
            writer_log(train_writer, train_scalars, iteration)

            # ``len(train_data)`` is what ``iteration`` above already treats as
            # the number of batches per epoch, and ``i_batch`` counts those
            # same batches, so it is the correct total to display. Dividing it
            # by the batch size again under-reported the total whenever the
            # batch size was larger than one.
            pstring = (
                "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                "sco {:07.05f} | pair {:05.03f} | des {:05.03f} |".format(
                    epoch,
                    i_batch,
                    len(train_data),
                    elapsed / cfg.TRAIN.LOG_INTERVAL,
                    train_scalars["score_loss"],
                    train_scalars["pair_loss"],
                    train_scalars["hard_loss"],
                )
            )

            # --- validation-sample log ---
            val_endpoint = model(parse_unsqueeze(val_data.dataset[0], device))
            val_plot, _, _ = _criterion(val_endpoint)
            val_scalars = val_plot["scalar"]
            val_scalars["Accuracy"] = getAC(
                val_endpoint["im1_lpdes"], val_endpoint["im1_rpdes"]
            )
            writer_log(test_writer, val_scalars, iteration)

            print(f"{gct()} | {pstring}")
            # Restart the timer so "ms" covers only the next logging window.
            start_time = time.time()
def train():
    """Alternately train the descriptor ``des`` and detector ``det`` over ``train_data``.

    For every batch the function runs ``cfg.TRAIN.DES`` descriptor updates
    (stepping ``des_optim``) followed by ``cfg.TRAIN.DET`` detector updates
    (stepping ``det_optim``), applies both learning-rate schedules, and every
    ``cfg.TRAIN.LOG_INTERVAL`` batches logs losses/accuracy for the first
    training sample and the first validation sample to ``train_writer`` /
    ``test_writer`` and prints a one-line summary.

    All collaborators (``det``, ``des``, ``train_data``, ``val_data``, ``cfg``,
    ``device``, ``epoch``, ``mgpu``, the optimizers and the writers) are free
    names resolved from the enclosing scope; the function takes no arguments
    and returns ``None``.

    NOTE(review): ``det.process`` / ``det.scale_list`` / ``des.input_norm`` are
    accessed directly on the networks while the criteria go through
    ``.module`` when ``mgpu`` is set - confirm the attribute access also works
    in the multi-GPU configuration.
    """

    def _criterion(net, endpoint):
        """Evaluate ``net``'s criterion, going through ``.module`` when ``mgpu`` is set."""
        if mgpu:
            return net.module.criterion(endpoint)
        return net.criterion(endpoint)

    def _detect(image):
        """Run the detector on ``image`` and post-process its score/orientation maps."""
        score_maps, orint_maps = det(image)
        return handle_det_out(
            score_maps,
            orint_maps,
            det.scale_list,
            det.score_com_strength,
            det.scale_com_strength,
        )

    def _describe(im1_pair, im2_pair):
        """Split each stacked patch pair along dim 1, normalise and describe the halves.

        Returns ``(im1_left, im1_right, im2_left, im2_right)`` descriptors.
        The original comment states each half is (N, 32, 32) - TODO confirm.
        """
        im1_left, im1_right = im1_pair.chunk(chunks=2, dim=1)
        im2_left, im2_right = im2_pair.chunk(chunks=2, dim=1)
        im1_left = des.input_norm(im1_left)
        im2_left = des.input_norm(im2_left)
        im1_right = des.input_norm(im1_right)
        im2_right = des.input_norm(im2_right)
        # Descriptor calls keep the original order (im1 L, im1 R, im2 L, im2 R).
        return des(im1_left), des(im1_right), des(im2_left), des(im2_right)

    def _log_scalars(endpoint):
        """Build the scalar dict (both losses plus accuracy) for one set of endpoints."""
        return {
            "pair_loss": _criterion(det, endpoint),
            "hard_loss": _criterion(des, endpoint),
            "Accuracy": getAC(endpoint["im1_lpdes"], endpoint["im1_rpdes"]),
        }

    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_data, 1):
        det.train()
        des.train()
        batch = parse_batch(sample_batched, device)
        # The batch layout does not change between inner iterations, so it is
        # unpacked once here instead of inside both training loops.
        (
            im1_data,
            im1_info,
            homo12,
            im2_data,
            im2_info,
            homo21,
            im1_raw,
            im2_raw,
        ) = batch

        with autograd.detect_anomaly():
            # ---------------- descriptor updates ----------------
            for _ in range(cfg.TRAIN.DES):
                des.zero_grad()
                des_optim.zero_grad()

                im1_rawsc, im1_scale, im1_orin = _detect(im1_data)
                im2_rawsc, im2_scale, im2_orin = _detect(im2_data)

                # Ground-truth scale/orientation for each image is derived
                # from the other image's predictions and the homographies.
                im1_gtscale, im1_gtorin = gt_scale_orin(
                    im2_scale, im2_orin, homo12, homo21
                )
                im2_gtscale, im2_gtorin = gt_scale_orin(
                    im1_scale, im1_orin, homo21, homo12
                )

                im1w_score = warp(filter_border(im2_rawsc), homo12)
                _, im1_topkmask, im1_topkvalue = det.process(im1w_score)

                im2w_score = warp(filter_border(im1_rawsc), homo21)
                _, im2_topkmask, im2_topkvalue = det.process(im2w_score)

                im1_ppair, im1_limc, im1_rimcw = pair(
                    im1_topkmask,
                    im1_topkvalue,
                    im1_scale,
                    im1_orin,
                    im1_info,
                    im1_raw,
                    homo12,
                    im2_gtscale,
                    im2_gtorin,
                    im2_info,
                    im2_raw,
                    cfg.PATCH.SIZE,
                )
                im2_ppair, im2_limc, im2_rimcw = pair(
                    im2_topkmask,
                    im2_topkvalue,
                    im2_scale,
                    im2_orin,
                    im2_info,
                    im2_raw,
                    homo21,
                    im1_gtscale,
                    im1_gtorin,
                    im1_info,
                    im1_raw,
                    cfg.PATCH.SIZE,
                )

                im1_lpdes, im1_rpdes, im2_lpdes, im2_rpdes = _describe(
                    im1_ppair, im2_ppair
                )

                endpoint = {
                    "im1_limc": im1_limc,
                    "im1_rimcw": im1_rimcw,
                    "im2_limc": im2_limc,
                    "im2_rimcw": im2_rimcw,
                    "im1_lpdes": im1_lpdes,
                    "im1_rpdes": im1_rpdes,
                    "im2_lpdes": im2_lpdes,
                    "im2_rpdes": im2_rpdes,
                }
                desloss = _criterion(des, endpoint)
                desloss.backward()
                des_optim.step()

            # ---------------- detector updates ----------------
            for _ in range(cfg.TRAIN.DET):
                det.zero_grad()
                det_optim.zero_grad()

                im1_rawsc, im1_scale, im1_orin = _detect(im1_data)
                im2_rawsc, im2_scale, im2_orin = _detect(im2_data)

                im2_border = filter_border(im2_rawsc)
                im1w_score = warp(im2_border, homo12)
                # Warping an all-ones map of the score shape with the same
                # homography yields the "visible" endpoint entry.
                im1_visiblemask = warp(
                    im2_border.new_full(
                        im2_border.size(), fill_value=1, requires_grad=True
                    ),
                    homo12,
                )
                im1_gtsc, im1_topkmask, im1_topkvalue = det.process(im1w_score)

                im1_border = filter_border(im1_rawsc)
                im2w_score = warp(im1_border, homo21)
                im2_visiblemask = warp(
                    im2_border.new_full(
                        im1_border.size(), fill_value=1, requires_grad=True
                    ),
                    homo21,
                )
                im2_gtsc, im2_topkmask, im2_topkvalue = det.process(im2w_score)

                # Predicted scores come from the un-warped raw score maps.
                im1_score = det.process(im1_rawsc)[0]
                im2_score = det.process(im2_rawsc)[0]

                # Unlike the descriptor step, patches here are extracted with
                # the predicted (not ground-truth) scale/orientation.
                im1_predpair, _, _ = pair(
                    im1_topkmask,
                    im1_topkvalue,
                    im1_scale,
                    im1_orin,
                    im1_info,
                    im1_raw,
                    homo12,
                    im2_scale,
                    im2_orin,
                    im2_info,
                    im2_raw,
                    cfg.PATCH.SIZE,
                )
                im2_predpair, _, _ = pair(
                    im2_topkmask,
                    im2_topkvalue,
                    im2_scale,
                    im2_orin,
                    im2_info,
                    im2_raw,
                    homo21,
                    im1_scale,
                    im1_orin,
                    im1_info,
                    im1_raw,
                    cfg.PATCH.SIZE,
                )

                (
                    im1_lpreddes,
                    im1_rpreddes,
                    im2_lpreddes,
                    im2_rpreddes,
                ) = _describe(im1_predpair, im2_predpair)

                endpoint = {
                    "im1_score": im1_score,
                    "im1_gtsc": im1_gtsc,
                    "im1_visible": im1_visiblemask,
                    "im2_score": im2_score,
                    "im2_gtsc": im2_gtsc,
                    "im2_visible": im2_visiblemask,
                    "im1_lpreddes": im1_lpreddes,
                    "im1_rpreddes": im1_rpreddes,
                    "im2_lpreddes": im2_lpreddes,
                    "im2_rpreddes": im2_rpreddes,
                }
                detloss = _criterion(det, endpoint)
                detloss.backward()
                det_optim.step()

        Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
        Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

        # i_batch starts at 1, so no extra "> 0" guard is needed here.
        if i_batch % cfg.TRAIN.LOG_INTERVAL != 0:
            continue

        elapsed = time.time() - start_time
        det.eval()
        des.eval()
        with torch.no_grad():
            # --- training-sample log ---
            parsed_trainbatch = parse_unsqueeze(train_data.dataset[0], device)
            train_endpoint = get_all_endpoints(det, des, parsed_trainbatch)
            train_scalars = _log_scalars(train_endpoint)
            train_scalars["det_lr"] = det_optim.param_groups[0]["lr"]
            train_scalars["des_lr"] = des_optim.param_groups[0]["lr"]
            if mgpu:
                mgpu_merge(train_scalars)

            iteration = (epoch - 1) * len(train_data) + (i_batch - 1)
            writer_log(train_writer, train_scalars, iteration)

            # ``len(train_data)`` is what ``iteration`` above already treats as
            # the number of batches per epoch, and ``i_batch`` counts those
            # same batches, so it is the correct total to display. Dividing it
            # by the batch size again under-reported the total whenever the
            # batch size was larger than one.
            pstring = (
                "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                "pair {:05.03f} | des {:05.03f} |".format(
                    epoch,
                    i_batch,
                    len(train_data),
                    elapsed / cfg.TRAIN.LOG_INTERVAL,
                    train_scalars["pair_loss"],
                    train_scalars["hard_loss"],
                )
            )

            # --- validation-sample log ---
            # The losses must be evaluated on the validation endpoints; the
            # earlier code passed the training endpoints to both criteria, so
            # the "validation" losses were copies of the training losses.
            parsed_valbatch = parse_unsqueeze(val_data.dataset[0], device)
            val_endpoint = get_all_endpoints(det, des, parsed_valbatch)
            val_scalars = _log_scalars(val_endpoint)
            writer_log(test_writer, val_scalars, iteration)

            print(f"{gct()} | {pstring}")
            # Restart the timer so "ms" covers only the next logging window.
            start_time = time.time()