Example #1
0
def main():
    """Run the detector on a single image and save a visualization.

    Command-line options come from ``make_parser()``:
    - ``args.file``: the network definition module.
    - ``args.weight_file``: the checkpoint.
    - ``args.image``: the input image.

    The image with the drawn detections is written to ``results.jpg`` in the
    current working directory.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.image``.
        OSError: if OpenCV fails to write ``results.jpg``.
    """
    parser = make_parser()
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()
    # All weights come from the checkpoint loaded below, so the backbone's
    # own pretrained initialization is switched off (presumably to avoid a
    # redundant load/download -- confirm against Cfg/Net).
    cfg.backbone_pretrained = False
    model = current_network.Net(cfg)
    model.eval()

    state_dict = mge.load(args.weight_file)
    # Accept both a bare state dict and a checkpoint that wraps the weights
    # under the "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)

    evaluator = DetEvaluator(model)

    ori_img = cv2.imread(args.image)
    # cv2.imread signals failure by returning None rather than raising.
    # Without this check the error would surface later as an opaque
    # AttributeError on ``.copy()``.
    if ori_img is None:
        raise FileNotFoundError(
            "cannot read image file: {}".format(args.image))
    # process_inputs gets a copy; the untouched original is what the
    # detections are drawn on below.
    image, im_info = DetEvaluator.process_inputs(
        ori_img.copy(),
        model.cfg.test_image_short_size,
        model.cfg.test_image_max_size,
    )
    pred_res = evaluator.predict(image=mge.tensor(image),
                                 im_info=mge.tensor(im_info))
    res_img = DetEvaluator.vis_det(
        ori_img,
        pred_res,
        is_show_label=True,
        classes=data_mapper[cfg.test_dataset["name"]].class_names,
    )
    # cv2.imwrite reports failure only through its boolean return value.
    if not cv2.imwrite("results.jpg", res_img):
        raise OSError("failed to write results.jpg")
Example #2
0
def worker(args):
    """Train the detector defined in ``args.file`` inside one process.

    Steps:
    1. Build the network in training mode.
    2. Collect every parameter that is not frozen by
       ``cfg.backbone_freeze_at`` and hand those to SGD. The learning rate
       is ``basic_lr * batch_size * world_size``.
    3. Attach the same parameters to a GradManager. Gradients are
       all-reduced (mean) when more than one process runs.
    4. Optionally load ``args.weight_file`` into the backbone's
       ``bottom_up`` module, then broadcast parameters and buffers across
       processes.
    5. Train for ``cfg.max_epoch`` epochs.

    Only rank 0 logs and writes ``log-of-<exp>/epoch_<n>.pkl`` checkpoints.
    """
    network_module = import_from_file(args.file)

    net = network_module.Net(network_module.Cfg())
    net.train()

    is_master = dist.get_rank() == 0
    if is_master:
        logger.info(get_config_info(net.cfg))
        logger.info(repr(net))

    def is_frozen(param_name):
        # freeze_at >= 1 excludes "bottom_up.conv1" parameters; freeze_at >= 2
        # additionally excludes "bottom_up.layer1" parameters.
        if "bottom_up.conv1" in param_name and net.cfg.backbone_freeze_at >= 1:
            return True
        return ("bottom_up.layer1" in param_name
                and net.cfg.backbone_freeze_at >= 2)

    trainable = [
        param for pname, param in net.named_parameters()
        if not is_frozen(pname)
    ]

    world_size = dist.get_world_size()
    # Learning rate scaled linearly by the global batch size.
    scaled_lr = net.cfg.basic_lr * args.batch_size * world_size
    opt = SGD(
        trainable,
        lr=scaled_lr,
        momentum=net.cfg.momentum,
        weight_decay=net.cfg.weight_decay,
    )

    gm = GradManager()
    if world_size > 1:
        allreduce = dist.make_allreduce_cb("mean", dist.WORLD)
        gm.attach(trainable, callbacks=[allreduce])
    else:
        gm.attach(trainable)

    if args.weight_file is not None:
        # Only the backbone's bottom_up module receives these weights;
        # strict=False tolerates keys that do not match.
        backbone_weights = mge.load(args.weight_file)
        net.backbone.bottom_up.load_state_dict(backbone_weights, strict=False)
    if world_size > 1:
        # Make every process start from identical parameters and buffers.
        dist.bcast_list_(net.parameters())
        dist.bcast_list_(net.buffers())

    if is_master:
        logger.info("Prepare dataset")
    loader_iter = iter(
        build_dataloader(args.batch_size, args.dataset_dir, net.cfg))

    exp_name = os.path.basename(args.file).split(".")[0]
    for epoch in range(net.cfg.max_epoch):
        train_one_epoch(net, loader_iter, opt, gm, epoch, args)
        if not is_master:
            continue
        ckpt_path = "log-of-{}/epoch_{}.pkl".format(exp_name, epoch)
        checkpoint = {"epoch": epoch, "state_dict": net.state_dict()}
        mge.save(checkpoint, ckpt_path)
        logger.info("dump weights to %s", ckpt_path)
Example #3
0
def main():
    """Evaluate detection checkpoints with the COCO bbox metrics.

    Checkpoint selection:
    - With ``--weight_file``, that single checkpoint is evaluated once.
    - Otherwise every ``log-of-<exp>/epoch_<n>.pkl`` for ``n`` in
      ``[start_epoch, end_epoch]`` is evaluated. Both bounds default to the
      last training epoch.

    For each checkpoint:
    1. Run inference, in ``args.devices`` worker processes when it is > 1.
    2. Dump the detections to a JSON file.
    3. Score them with pycocotools ``COCOeval`` and log the 12 summary
       statistics.
    """
    # pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    parser = make_parser()
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()
    exp_name = os.path.basename(args.file).split(".")[0]

    if args.weight_file:
        # A single explicit checkpoint: the epoch loop below runs once.
        args.start_epoch = args.end_epoch = -1
    else:
        if args.start_epoch == -1:
            args.start_epoch = cfg.max_epoch - 1
        if args.end_epoch == -1:
            args.end_epoch = args.start_epoch
        assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch

    for epoch_num in range(args.start_epoch, args.end_epoch + 1):
        if args.weight_file:
            weight_file = args.weight_file
        else:
            weight_file = "log-of-{}/epoch_{}.pkl".format(exp_name, epoch_num)

        result_list = []
        if args.devices > 1:
            result_queue = Queue(2000)

            master_ip = "localhost"
            server = dist.Server()
            port = server.py_server_port
            procs = []
            for i in range(args.devices):
                proc = Process(
                    target=worker,
                    args=(
                        current_network,
                        weight_file,
                        args.dataset_dir,
                        result_queue,
                        master_ip,
                        port,
                        args.devices,
                        i,
                    ),
                )
                proc.start()
                procs.append(proc)

            # Presumably one queue item per test image. The counts are
            # hard-coded, so an unlisted dataset name raises KeyError here,
            # and get() blocks if the workers produce fewer items.
            num_imgs = dict(coco=5000, objects365=30000)

            for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
                result_list.append(result_queue.get())

            for p in procs:
                p.join()
        else:
            # Single device: the worker appends straight into result_list.
            worker(current_network, weight_file, args.dataset_dir, result_list)

        all_results = DetEvaluator.format(result_list, cfg)
        if args.weight_file:
            # An explicit checkpoint has no meaningful epoch (-1), and the
            # log-of-<exp> directory may not exist. Name the file after the
            # config and the weights instead (same scheme as the
            # launcher-based evaluator).
            json_path = "{}_{}.json".format(
                exp_name,
                os.path.basename(args.weight_file).split(".")[0],
            )
        else:
            json_path = "log-of-{}/epoch_{}.json".format(exp_name, epoch_num)
        all_results = json.dumps(all_results)

        with open(json_path, "w") as fo:
            fo.write(all_results)
        logger.info("Save to %s finished, start evaluation!", json_path)

        eval_gt = COCO(
            os.path.join(args.dataset_dir, cfg.test_dataset["name"],
                         cfg.test_dataset["ann_file"]))
        eval_dt = eval_gt.loadRes(json_path)
        cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        # Labels for COCOeval.stats, in the order summarize() fills it.
        metrics = [
            "AP",
            "AP@0.5",
            "AP@0.75",
            "APs",
            "APm",
            "APl",
            "AR@1",
            "AR@10",
            "AR@100",
            "ARs",
            "ARm",
            "ARl",
        ]
        logger.info("mmAP".center(32, "-"))
        for i, m in enumerate(metrics):
            logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i])
        logger.info("-" * 32)
Example #4
0
def main():
    """Evaluate detection checkpoints with the COCO bbox metrics.

    Checkpoint selection:
    - With ``--weight_file``, that single checkpoint is evaluated once and
      the detections go to ``<exp>_<weights>.json``.
    - Otherwise every ``log-of-<exp>/epoch_<n>.pkl`` for ``n`` in
      ``[start_epoch, end_epoch]`` is evaluated. Both bounds default to the
      last training epoch.

    Inference runs through ``dist.launcher`` when ``args.devices`` > 1.
    Results are scored with pycocotools ``COCOeval`` and the 12 summary
    statistics are logged.
    """
    # pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    parser = make_parser()
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()

    if args.weight_file:
        # A single explicit checkpoint: the epoch loop below runs once.
        args.start_epoch = args.end_epoch = -1
    else:
        if args.start_epoch == -1:
            args.start_epoch = cfg.max_epoch - 1
        if args.end_epoch == -1:
            args.end_epoch = args.start_epoch
        assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch

    for epoch_num in range(args.start_epoch, args.end_epoch + 1):
        if args.weight_file:
            weight_file = args.weight_file
        else:
            weight_file = "log-of-{}/epoch_{}.pkl".format(
                os.path.basename(args.file).split(".")[0], epoch_num)

        if args.devices > 1:
            dist_worker = dist.launcher(n_gpus=args.devices)(worker)
            # The launcher appears to return one result list per process.
            # Flatten them in linear time; sum(..., []) re-copies the
            # accumulated list at every step.
            per_device_results = dist_worker(current_network, weight_file,
                                             args.dataset_dir)
            result_list = [
                res for chunk in per_device_results for res in chunk
            ]
        else:
            result_list = worker(current_network, weight_file,
                                 args.dataset_dir)

        all_results = DetEvaluator.format(result_list, cfg)
        if args.weight_file:
            json_path = "{}_{}.json".format(
                os.path.basename(args.file).split(".")[0],
                os.path.basename(args.weight_file).split(".")[0],
            )
        else:
            json_path = "log-of-{}/epoch_{}.json".format(
                os.path.basename(args.file).split(".")[0], epoch_num)
        all_results = json.dumps(all_results)

        with open(json_path, "w") as fo:
            fo.write(all_results)
        logger.info("Save results to %s, start evaluation!", json_path)

        eval_gt = COCO(
            os.path.join(args.dataset_dir, cfg.test_dataset["name"],
                         cfg.test_dataset["ann_file"]))
        eval_dt = eval_gt.loadRes(json_path)
        cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        # Labels for COCOeval.stats, in the order summarize() fills it.
        metrics = [
            "AP",
            "AP@0.5",
            "AP@0.75",
            "APs",
            "APm",
            "APl",
            "AR@1",
            "AR@10",
            "AR@100",
            "ARs",
            "ARm",
            "ARl",
        ]
        logger.info("mmAP".center(32, "-"))
        for i, m in enumerate(metrics):
            logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i])
        logger.info("-" * 32)