Exemplo n.º 1
0
def main(args):
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    if args.enable_ce:
        random.seed(0)
        np.random.seed(0)

    cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
    cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))

    cfg.check_and_infer()
    print_info(pprint.pformat(cfg))

    name = cfg.TRAIN.MODEL_SAVE_DIR.replace("output/", "")
    store_config = {
        "model_name": cfg.MODEL.MODEL_NAME,
        "solver": cfg.SOLVER.OPTIMIZER,
        "lr": cfg.SOLVER.LR,
        "lr_policy": cfg.SOLVER.LR_POLICY,
    }
    wandb.init(project="rs_segmentation",
               name=name,
               config=store_config,
               dir=cfg.LOG_DIR,
               notes="")

    train(cfg)
Exemplo n.º 2
0
def main(args):
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts is not None:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer(reset_dataset=True)
    print(pprint.pformat(cfg))
    train(cfg)
Exemplo n.º 3
0
def main():
    args = parse_args()
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    export_inference_model(args)
Exemplo n.º 4
0
def main():
    args = parse_args()
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    evaluate(cfg, **args.__dict__)
Exemplo n.º 5
0
def parse_args():
    parser = argparse.ArgumentParser(description="数据预处理")
    parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
    parser.add_argument("opts", nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
Exemplo n.º 6
0
def main(args):
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)

    cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
    cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))

    cfg.check_and_infer()
    print_info(pprint.pformat(cfg))
    train(cfg)
Exemplo n.º 7
0
def parse_args():
    parser = argparse.ArgumentParser(description="预测")
    parser.add_argument("-c", "--cfg_file", type=str, help="配置文件路径")
    parser.add_argument("--use_gpu",
                        action="store_true",
                        default=False,
                        help="使用GPU推理")
    parser.add_argument("opts", nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    if args.use_gpu:  # 命令行参数只能从false改成true,不能声明false
        cfg.TRAIN.USE_GPU = True

    cfg.set_immutable(True)
Exemplo n.º 8
0
                # BGR->RGB
                img = cv2.imread(
                    os.path.join(cfg.DATASET.DATA_DIR,
                                 img_names[i]))[..., ::-1]
                log_writer.add_image("Images/{}".format(img_names[i]),
                                     img,
                                     epoch,
                                     dataformats='HWC')
                #add ground truth (label) images
                if grt is not None:
                    log_writer.add_image("Label/{}".format(img_names[i]),
                                         grt[..., ::-1],
                                         epoch,
                                         dataformats='HWC')

        # If in local_test mode, only visualize 5 images just for testing
        # procedure
        if local_test and img_cnt >= 5:
            break


if __name__ == '__main__':
    args = parse_args()
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    visualize(cfg, **args.__dict__)