コード例 #1
0
ファイル: train.py プロジェクト: iamwangyabin/FDLNet
    def train():
        """Run one epoch of alternating descriptor/detector training.

        For every batch from ``train_data`` the descriptor branch is updated
        ``cfg.TRAIN.DES`` times (stepping ``des_optim`` on the descriptor loss)
        and then the detector branch ``cfg.TRAIN.DET`` times (stepping
        ``det_optim`` on the detector loss). Both learning-rate schedules are
        applied after every batch. Every ``cfg.TRAIN.LOG_INTERVAL`` batches,
        scalars computed on the first training sample and on the first
        validation sample are written to ``train_writer`` / ``test_writer`` and
        a progress line is printed.

        Reads from the enclosing scope: ``model``, ``train_data``, ``val_data``,
        ``device``, ``cfg``, ``epoch``, ``mgpu``, ``des_optim``, ``det_optim``,
        ``train_writer``, ``test_writer`` and the project helpers it calls.
        """

        def _criterion(endpoint):
            # A multi-GPU (wrapped) model exposes ``criterion`` on ``.module``.
            # Returns (plot dict, detector loss, descriptor loss).
            return (model.module.criterion(endpoint)
                    if mgpu else model.criterion(endpoint))

        # ``train_data`` is iterated batch by batch, so its length is the
        # number of batches per epoch (also used for the global step below).
        batches_per_epoch = len(train_data)
        start_time = time.time()
        for i_batch, sample_batched in enumerate(train_data, 1):
            model.train()
            batch = parse_batch(sample_batched, device)
            # NOTE: anomaly detection makes every backward pass much slower;
            # kept as in the original training setup.
            with autograd.detect_anomaly():
                for _ in range(cfg.TRAIN.DES):
                    model.zero_grad()
                    des_optim.zero_grad()
                    _, _, desloss = _criterion(model(batch))
                    desloss.backward()
                    des_optim.step()
                for _ in range(cfg.TRAIN.DET):
                    model.zero_grad()
                    det_optim.zero_grad()
                    _, detloss, _ = _criterion(model(batch))
                    detloss.backward()
                    det_optim.step()

            # Called after every batch with the current epoch.
            # NOTE(review): presumably sets the lr from ``epoch`` — confirm
            # Lr_Schechuler is idempotent within an epoch.
            Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
            Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

            # i_batch starts at 1, so no separate "> 0" guard is needed.
            if i_batch % cfg.TRAIN.LOG_INTERVAL != 0:
                continue

            elapsed = time.time() - start_time
            model.eval()
            with torch.no_grad():
                # Training scalars, from the first sample of the train set.
                eptr = model(parse_unsqueeze(train_data.dataset[0], device))
                PLT, _, _ = _criterion(eptr)
                PLTS = PLT["scalar"]
                PLTS["Accuracy"] = getAC(eptr["im1_lpdes"], eptr["im1_rpdes"])
                PLTS["det_lr"] = det_optim.param_groups[0]["lr"]
                PLTS["des_lr"] = des_optim.param_groups[0]["lr"]
                if mgpu:
                    mgpu_merge(PLTS)
                iteration = (epoch - 1) * batches_per_epoch + (i_batch - 1)
                writer_log(train_writer, PLTS, iteration)

                pstring = (
                    "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                    "sco {:07.05f} | pair {:05.03f} | des {:05.03f} |".format(
                        epoch,
                        i_batch,
                        batches_per_epoch,
                        # elapsed is in seconds; convert to milliseconds per
                        # batch to match the "ms" label.
                        elapsed * 1000.0 / cfg.TRAIN.LOG_INTERVAL,
                        PLTS["score_loss"],
                        PLTS["pair_loss"],
                        PLTS["hard_loss"],
                    ))

                # Validation scalars, from the first sample of the val set.
                # NOTE(review): unlike the train scalars, these are not passed
                # through mgpu_merge — confirm that is intended.
                ept = model(parse_unsqueeze(val_data.dataset[0], device))
                PLT, _, _ = _criterion(ept)
                PLTS = PLT["scalar"]
                PLTS["Accuracy"] = getAC(ept["im1_lpdes"], ept["im1_rpdes"])
                writer_log(test_writer, PLTS, iteration)

                print(f"{gct()} | {pstring}")
                start_time = time.time()
コード例 #2
0
    def train():
        """Run one epoch of alternating descriptor/detector training.

        The detector ``det`` and descriptor ``des`` are separate networks with
        their own optimizers. For every batch from ``train_data``:

        * the descriptor is updated ``cfg.TRAIN.DES`` times. Patch pairs are
          cut with ``pair`` at the top-k mask/values returned by
          ``det.process`` on the warped score maps. The partner patch uses
          the scale/orientation that ``gt_scale_orin`` derives from the other
          image's predictions;
        * the detector is updated ``cfg.TRAIN.DET`` times. Its loss is built
          from its score maps, the warped targets from ``det.process``, the
          warped visibility masks, and descriptors of patches cut with the
          detector's own predicted scale/orientation in both images.

        Both learning-rate schedules are applied after every batch. Every
        ``cfg.TRAIN.LOG_INTERVAL`` batches, losses and accuracy on the first
        training sample and on the first validation sample are written to
        ``train_writer`` / ``test_writer`` and a progress line is printed.

        Reads from the enclosing scope: ``det``, ``des``, ``train_data``,
        ``val_data``, ``device``, ``cfg``, ``epoch``, ``mgpu``, ``det_optim``,
        ``des_optim``, ``train_writer``, ``test_writer`` and project helpers.
        """

        def _criterion(net, endpoint):
            # A multi-GPU (wrapped) network exposes ``criterion`` on ``.module``.
            return (net.module.criterion(endpoint)
                    if mgpu else net.criterion(endpoint))

        def _detect(im_data):
            # One detector forward pass, post-processed by handle_det_out into
            # (raw score map, scale map, orientation map).
            score_maps, orint_maps = det(im_data)
            return handle_det_out(score_maps, orint_maps, det.scale_list,
                                  det.score_com_strength,
                                  det.scale_com_strength)

        def _visible_mask(score, homo):
            # Warp an all-ones map shaped like ``score`` with ``homo``; the
            # detector criterion receives it as the "visible" entry.
            return warp(
                score.new_full(score.size(), fill_value=1, requires_grad=True),
                homo,
            )

        def _describe(im1_ppair, im2_ppair):
            # Split each patch-pair tensor into left/right halves along dim 1,
            # normalize them, and describe them. The call order of input_norm
            # and des matches the original inline code. Returns descriptors as
            # (im1 left, im1 right, im2 left, im2 right).
            im1_l, im1_r = im1_ppair.chunk(chunks=2, dim=1)
            im2_l, im2_r = im2_ppair.chunk(chunks=2, dim=1)
            im1_l = des.input_norm(im1_l)
            im2_l = des.input_norm(im2_l)
            im1_r = des.input_norm(im1_r)
            im2_r = des.input_norm(im2_r)
            return des(im1_l), des(im1_r), des(im2_l), des(im2_r)

        def _eval_scalars(endpoint):
            # Detector loss, descriptor loss and descriptor accuracy for one
            # endpoint dict produced by get_all_endpoints.
            return {
                "pair_loss": _criterion(det, endpoint),
                "hard_loss": _criterion(des, endpoint),
                "Accuracy": getAC(endpoint["im1_lpdes"],
                                  endpoint["im1_rpdes"]),
            }

        # ``train_data`` is iterated batch by batch, so its length is the
        # number of batches per epoch (also used for the global step below).
        batches_per_epoch = len(train_data)
        start_time = time.time()
        for i_batch, sample_batched in enumerate(train_data, 1):
            det.train()
            des.train()
            batch = parse_batch(sample_batched, device)
            (im1_data, im1_info, homo12, im2_data, im2_info, homo21, im1_raw,
             im2_raw) = batch
            # NOTE: anomaly detection makes every backward pass much slower;
            # kept as in the original training setup.
            with autograd.detect_anomaly():
                # ---- descriptor updates ----
                for _ in range(cfg.TRAIN.DES):
                    des.zero_grad()
                    des_optim.zero_grad()
                    im1_rawsc, im1_scale, im1_orin = _detect(im1_data)
                    im2_rawsc, im2_scale, im2_orin = _detect(im2_data)
                    # Partner scale/orientation for each image, derived from
                    # the other image's predictions and the two homographies.
                    im1_gtscale, im1_gtorin = gt_scale_orin(
                        im2_scale, im2_orin, homo12, homo21)
                    im2_gtscale, im2_gtorin = gt_scale_orin(
                        im1_scale, im1_orin, homo21, homo12)
                    im1w_score = warp(filter_border(im2_rawsc), homo12)
                    _, im1_topkmask, im1_topkvalue = det.process(im1w_score)
                    im2w_score = warp(filter_border(im1_rawsc), homo21)
                    _, im2_topkmask, im2_topkvalue = det.process(im2w_score)

                    im1_ppair, im1_limc, im1_rimcw = pair(
                        im1_topkmask, im1_topkvalue, im1_scale, im1_orin,
                        im1_info, im1_raw, homo12, im2_gtscale, im2_gtorin,
                        im2_info, im2_raw, cfg.PATCH.SIZE,
                    )
                    im2_ppair, im2_limc, im2_rimcw = pair(
                        im2_topkmask, im2_topkvalue, im2_scale, im2_orin,
                        im2_info, im2_raw, homo21, im1_gtscale, im1_gtorin,
                        im1_info, im1_raw, cfg.PATCH.SIZE,
                    )
                    im1_lpdes, im1_rpdes, im2_lpdes, im2_rpdes = _describe(
                        im1_ppair, im2_ppair)
                    endpoint = {
                        "im1_limc": im1_limc,
                        "im1_rimcw": im1_rimcw,
                        "im2_limc": im2_limc,
                        "im2_rimcw": im2_rimcw,
                        "im1_lpdes": im1_lpdes,
                        "im1_rpdes": im1_rpdes,
                        "im2_lpdes": im2_lpdes,
                        "im2_rpdes": im2_rpdes,
                    }
                    desloss = _criterion(des, endpoint)
                    desloss.backward()
                    des_optim.step()

                # ---- detector updates ----
                for _ in range(cfg.TRAIN.DET):
                    det.zero_grad()
                    det_optim.zero_grad()
                    im1_rawsc, im1_scale, im1_orin = _detect(im1_data)
                    im2_rawsc, im2_scale, im2_orin = _detect(im2_data)

                    im2_score = filter_border(im2_rawsc)
                    im1w_score = warp(im2_score, homo12)
                    im1_visiblemask = _visible_mask(im2_score, homo12)
                    im1_gtsc, im1_topkmask, im1_topkvalue = det.process(
                        im1w_score)

                    im1_score = filter_border(im1_rawsc)
                    im2w_score = warp(im1_score, homo21)
                    im2_visiblemask = _visible_mask(im1_score, homo21)
                    im2_gtsc, im2_topkmask, im2_topkvalue = det.process(
                        im2w_score)

                    # Scores fed to the criterion come from det.process on the
                    # raw (not border-filtered) maps.
                    im1_score = det.process(im1_rawsc)[0]
                    im2_score = det.process(im2_rawsc)[0]
                    im1_predpair, _, _ = pair(
                        im1_topkmask, im1_topkvalue, im1_scale, im1_orin,
                        im1_info, im1_raw, homo12, im2_scale, im2_orin,
                        im2_info, im2_raw, cfg.PATCH.SIZE,
                    )
                    im2_predpair, _, _ = pair(
                        im2_topkmask, im2_topkvalue, im2_scale, im2_orin,
                        im2_info, im2_raw, homo21, im1_scale, im1_orin,
                        im1_info, im1_raw, cfg.PATCH.SIZE,
                    )
                    (im1_lpreddes, im1_rpreddes, im2_lpreddes,
                     im2_rpreddes) = _describe(im1_predpair, im2_predpair)
                    endpoint = {
                        "im1_score": im1_score,
                        "im1_gtsc": im1_gtsc,
                        "im1_visible": im1_visiblemask,
                        "im2_score": im2_score,
                        "im2_gtsc": im2_gtsc,
                        "im2_visible": im2_visiblemask,
                        "im1_lpreddes": im1_lpreddes,
                        "im1_rpreddes": im1_rpreddes,
                        "im2_lpreddes": im2_lpreddes,
                        "im2_rpreddes": im2_rpreddes,
                    }
                    detloss = _criterion(det, endpoint)
                    detloss.backward()
                    det_optim.step()

            # Called after every batch with the current epoch.
            # NOTE(review): presumably sets the lr from ``epoch`` — confirm
            # Lr_Schechuler is idempotent within an epoch.
            Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
            Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

            # i_batch starts at 1, so no separate "> 0" guard is needed.
            if i_batch % cfg.TRAIN.LOG_INTERVAL != 0:
                continue

            elapsed = time.time() - start_time
            det.eval()
            des.eval()
            with torch.no_grad():
                # Training scalars, from the first sample of the train set.
                parsed_trainbatch = parse_unsqueeze(train_data.dataset[0],
                                                    device)
                PLTS = _eval_scalars(
                    get_all_endpoints(det, des, parsed_trainbatch))
                PLTS["det_lr"] = det_optim.param_groups[0]["lr"]
                PLTS["des_lr"] = des_optim.param_groups[0]["lr"]
                if mgpu:
                    mgpu_merge(PLTS)
                iteration = (epoch - 1) * batches_per_epoch + (i_batch - 1)
                writer_log(train_writer, PLTS, iteration)

                pstring = (
                    "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                    "pair {:05.03f} | des {:05.03f} |".format(
                        epoch,
                        i_batch,
                        batches_per_epoch,
                        # elapsed is in seconds; convert to milliseconds per
                        # batch to match the "ms" label.
                        elapsed * 1000.0 / cfg.TRAIN.LOG_INTERVAL,
                        PLTS["pair_loss"],
                        PLTS["hard_loss"],
                    ))

                # Validation scalars, from the first sample of the val set.
                # Both losses are evaluated on the validation endpoints (they
                # were previously computed on the training endpoints).
                # NOTE(review): unlike the train scalars, these are not passed
                # through mgpu_merge — confirm that is intended.
                parsed_valbatch = parse_unsqueeze(val_data.dataset[0], device)
                val_scalars = _eval_scalars(
                    get_all_endpoints(det, des, parsed_valbatch))
                writer_log(test_writer, val_scalars, iteration)

                print(f"{gct()} | {pstring}")
                start_time = time.time()