Example #1
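The imports below are implied by the snippet; import_from_file() and inference() are helpers defined elsewhere in the same project.

import argparse

import cv2
import megengine as mge
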
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--file", default="net.py", type=str, help="net description file"
    )
    parser.add_argument(
        "-w", "--weight_file", default=None, type=str, help="weights file",
    )
    parser.add_argument("-i", "--image", type=str)
    args = parser.parse_args()

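    # Build the model from the network description file; pretrained backbone
    # weights are not downloaded because a full checkpoint is loaded below.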
    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()
    cfg.backbone_pretrained = False
    model = current_network.Net(cfg)
    model.eval()

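    # Checkpoints saved during training wrap the parameters in a "state_dict"
    # key; plain weight files are loaded as-is.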
    state_dict = mge.load(args.weight_file)
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict)

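    # Run inference on a single image and save the prediction as results.jpg.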
    img = cv2.imread(args.image)
    pred = inference(img, model)
    cv2.imwrite("results.jpg", pred)
Example #2
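The imports below are implied by the snippet; import_from_file(), worker(), save_results(), compute_metric(), and logger are defined elsewhere in the same project.

import argparse
from multiprocessing import Process, Queue

import megengine.distributed as dist
from tqdm import tqdm
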
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--file", default="net.py", type=str, help="net description file"
    )
    parser.add_argument(
        "-w", "--weight_file", default=None, type=str, help="weights file",
    )
    parser.add_argument(
        "-n", "--devices", default=1, type=int, help="total number of gpus for testing",
    )
    parser.add_argument(
        "-d", "--dataset_dir", default="/data/datasets", type=str,
    )
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()

    result_list = []
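    # With several devices, spawn one worker process per GPU and collect
    # per-image results through a shared queue; otherwise run a single worker
    # that appends its results directly into result_list.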
    if args.devices > 1:
        result_queue = Queue(500)

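        # Start the server that the worker processes use to rendezvous.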
        master_ip = "localhost"
        server = dist.Server()
        port = server.py_server_port

        procs = []
        for i in range(args.devices):
            proc = Process(
                target=worker,
                args=(
                    current_network,
                    args.weight_file,
                    args.dataset_dir,
                    result_queue,
                    master_ip,
                    port,
                    args.devices,
                    i,
                ),
            )
            proc.start()
            procs.append(proc)

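        # Number of validation images per supported dataset, i.e. how many
        # results to read from the queue before joining the workers.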
        num_imgs = dict(VOC2012=1449, Cityscapes=500)

        for _ in tqdm(range(num_imgs[cfg.dataset])):
            result_list.append(result_queue.get())

        for p in procs:
            p.join()
    else:
        worker(current_network, args.weight_file, args.dataset_dir, result_list)

    if cfg.val_save_path is not None:
        save_results(result_list, cfg.val_save_path, cfg)
    logger.info("Start evaluation!")
    compute_metric(result_list, cfg)
Example #3
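The imports below are implied by the snippet; import_from_file(), worker(), save_results(), compute_metric(), and logger are defined elsewhere in the same project.

import argparse

import megengine.distributed as dist
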
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f",
                        "--file",
                        default="net.py",
                        type=str,
                        help="net description file")
    parser.add_argument(
        "-w",
        "--weight_file",
        default=None,
        type=str,
        help="weights file",
    )
    parser.add_argument(
        "-n",
        "--devices",
        default=1,
        type=int,
        help="total number of gpus for testing",
    )
    parser.add_argument(
        "-d",
        "--dataset_dir",
        default="/data/datasets",
        type=str,
    )
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()

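    # With several devices, dist.launcher runs worker() once per GPU and
    # gathers each process's return value into a list of lists, which is
    # flattened here; otherwise worker() runs in the current process.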
    if args.devices > 1:
        dist_worker = dist.launcher(n_gpus=args.devices)(worker)
        result_list = dist_worker(current_network, args.weight_file,
                                  args.dataset_dir)
        result_list = sum(result_list, [])
    else:
        result_list = worker(current_network, args.weight_file,
                             args.dataset_dir)

    if cfg.val_save_path is not None:
        save_results(result_list, cfg.val_save_path, cfg)
    logger.info("Start evaluation!")
    compute_metric(result_list, cfg)
Example #4
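The imports below are implied by the snippet; import_from_file(), get_config_info(), build_dataloader(), train_one_epoch(), and logger are defined elsewhere in the same project.

import os

import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.optimizer import SGD
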
def worker(args):
    current_network = import_from_file(args.file)

    model = current_network.Net(current_network.Cfg())
    model.train()

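    # Only rank 0 logs, to avoid duplicated output across workers.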
    if dist.get_rank() == 0:
        logger.info(get_config_info(model.cfg))
        logger.info(repr(model))

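    # Split the parameters so the backbone is trained with a 10x smaller
    # learning rate than the head.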
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if "backbone" in name:
            backbone_params.append(param)
        else:
            head_params.append(param)

    opt = SGD(
        [
            {
                "params": backbone_params,
                "lr": model.cfg.learning_rate * 0.1
            },
            {
                "params": head_params
            },
        ],
        lr=model.cfg.learning_rate,
        momentum=model.cfg.momentum,
        weight_decay=model.cfg.weight_decay * dist.get_world_size(),
    )

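    # Attach the parameters to the gradient manager; with multiple GPUs an
    # all-reduce callback sums gradients across workers after each backward
    # pass (which is also why weight_decay above is scaled by the world size).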
    gm = GradManager()
    if dist.get_world_size() > 1:
        gm.attach(model.parameters(),
                  callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)])
    else:
        gm.attach(model.parameters())

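    # Optionally resume from a checkpoint, restoring the epoch counter, model
    # weights, and optimizer state.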
    cur_epoch = 0
    if args.resume is not None:
        pretrained = mge.load(args.resume)
        cur_epoch = pretrained["epoch"] + 1
        model.load_state_dict(pretrained["state_dict"])
        opt.load_state_dict(pretrained["opt"])
        if dist.get_rank() == 0:
            logger.info("load success: epoch %d", cur_epoch)

    if dist.get_world_size() > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    if dist.get_rank() == 0:
        logger.info("Prepare dataset")
    train_loader = iter(
        build_dataloader(model.cfg.batch_size, args.dataset_dir, model.cfg))

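    # Train epoch by epoch; rank 0 saves a checkpoint (epoch, weights, and
    # optimizer state) after each epoch.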
    for epoch in range(cur_epoch, model.cfg.max_epoch):
        train_one_epoch(model, train_loader, opt, gm, epoch)
        if dist.get_rank() == 0:
            save_path = "log-of-{}/epoch_{}.pkl".format(
                os.path.basename(args.file).split(".")[0], epoch)
            mge.save(
                {
                    "epoch": epoch,
                    "state_dict": model.state_dict(),
                    "opt": opt.state_dict()
                }, save_path)
            logger.info("dump weights to %s", save_path)