예제 #1
0
def main(args):
    """Build the global training config, register a W&B run, and train.

    Steps, in order:
      1. Merge ``args.cfg_file`` (if given) and then ``args.opts`` (if
         non-empty) into the global ``cfg``.
      2. When ``args.enable_ce`` is truthy, seed Python's ``random`` and
         NumPy's global RNG with 0.
      3. Record ``cfg.TRAINER_ID`` / ``cfg.NUM_TRAINERS`` from the
         ``PADDLE_TRAINER_ID`` / ``PADDLE_TRAINERS_NUM`` environment
         variables (defaulting to 0 and 1).
      4. Validate/infer the config and log it via ``print_info``.
      5. Start a ``wandb`` run in project "rs_segmentation", named after the
         model save directory, then call ``train(cfg)``.

    Args:
        args: parsed command-line namespace; ``cfg_file``, ``opts`` and
            ``enable_ce`` are read.
    """
    config_file = args.cfg_file
    if config_file is not None:
        cfg.update_from_file(config_file)
    overrides = args.opts
    if overrides:  # None and an empty list both mean "no overrides"
        cfg.update_from_list(overrides)
    if args.enable_ce:
        # Fixed seeds so repeated runs draw the same random numbers.
        random.seed(0)
        np.random.seed(0)

    env = os.environ
    cfg.TRAINER_ID = int(env.get("PADDLE_TRAINER_ID", 0))
    cfg.NUM_TRAINERS = int(env.get('PADDLE_TRAINERS_NUM', 1))

    cfg.check_and_infer()
    print_info(pprint.pformat(cfg))

    # NOTE: str.replace strips every "output/" occurrence, not only a prefix.
    run_name = cfg.TRAIN.MODEL_SAVE_DIR.replace("output/", "")
    hyperparams = dict(
        model_name=cfg.MODEL.MODEL_NAME,
        solver=cfg.SOLVER.OPTIMIZER,
        lr=cfg.SOLVER.LR,
        lr_policy=cfg.SOLVER.LR_POLICY,
    )
    wandb.init(
        project="rs_segmentation",
        name=run_name,
        config=hyperparams,
        dir=cfg.LOG_DIR,
        notes="",
    )

    train(cfg)
예제 #2
0
파일: train.py 프로젝트: nepeplwu/PaddleSeg
def main(args):
    """Merge command-line config sources into ``cfg`` and run training.

    The config file is applied first, then the ``opts`` overrides, so the
    overrides win. The merged config is validated with
    ``reset_dataset=True``, printed, and passed to ``train``.

    Args:
        args: parsed command-line namespace; ``cfg_file`` (path or None)
            and ``opts`` (overrides for ``cfg.update_from_list`` or None)
            are read.
    """
    config_file = args.cfg_file
    if config_file is not None:
        cfg.update_from_file(config_file)
    overrides = args.opts
    if overrides is not None:
        cfg.update_from_list(overrides)
    cfg.check_and_infer(reset_dataset=True)
    rendered = pprint.pformat(cfg)
    print(rendered)
    train(cfg)
예제 #3
0
def main():
    """Export an inference model described by the command-line config.

    Parses arguments with ``parse_args``, merges the config file and any
    ``opts`` overrides into the global ``cfg``, validates and prints it,
    then hands the parsed namespace to ``export_inference_model``.
    """
    cli_args = parse_args()

    config_file = cli_args.cfg_file
    if config_file is not None:
        cfg.update_from_file(config_file)

    overrides = cli_args.opts
    if overrides:  # skip both None and an empty list
        cfg.update_from_list(overrides)

    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    export_inference_model(cli_args)
예제 #4
0
def main():
    """Evaluate a model using settings taken from the command line.

    Parses arguments, merges the config file and any ``opts`` overrides
    into the global ``cfg``, validates and prints it, then calls
    ``evaluate`` with ``cfg`` plus every parsed argument as a keyword.
    """
    cli_args = parse_args()

    config_file = cli_args.cfg_file
    if config_file is not None:
        cfg.update_from_file(config_file)

    overrides = cli_args.opts
    if overrides:  # skip both None and an empty list
        cfg.update_from_list(overrides)

    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    # vars() returns the namespace's attribute dict, forwarded as kwargs.
    evaluate(cfg, **vars(cli_args))
예제 #5
0
def main(args):
    """Merge configuration, record trainer info from the environment, train.

    The config file is applied first, then non-empty ``opts`` overrides.
    ``cfg.TRAINER_ID`` and ``cfg.NUM_TRAINERS`` are read from the
    ``PADDLE_TRAINER_ID`` / ``PADDLE_TRAINERS_NUM`` environment variables
    (defaults 0 and 1) before the config is validated, logged via
    ``print_info``, and passed to ``train``.

    Args:
        args: parsed command-line namespace; ``cfg_file`` and ``opts`` are
            read.
    """
    config_file = args.cfg_file
    if config_file is not None:
        cfg.update_from_file(config_file)
    overrides = args.opts
    if overrides:  # None and an empty list both mean "no overrides"
        cfg.update_from_list(overrides)

    env = os.environ
    cfg.TRAINER_ID = int(env.get("PADDLE_TRAINER_ID", 0))
    cfg.NUM_TRAINERS = int(env.get('PADDLE_TRAINERS_NUM', 1))

    cfg.check_and_infer()
    print_info(pprint.pformat(cfg))
    train(cfg)
예제 #6
0
def main(args):
    """Run the dataset sanity checks for the configured splits.

    Loads ``args.cfg_file`` (if given) into ``cfg``, validates it with
    ``reset_dataset=True``, and logs it. It then runs the train, val and
    test dataset checks in that order, calling ``init_global_variable()``
    before each one. That call presumably resets state accumulated by the
    previous check; confirm against its definition. Finally it runs
    ``inf_resize_value_check()``.

    Args:
        args: parsed command-line namespace; only ``cfg_file`` is read.
    """
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    cfg.check_and_infer(reset_dataset=True)
    logger.info(pprint.pformat(cfg))

    split_checks = (check_train_dataset, check_val_dataset, check_test_dataset)
    for run_check in split_checks:
        init_global_variable()
        run_check()

    inf_resize_value_check()
예제 #7
0
def main(args):
    """Run the dataset sanity checks and point the user at the detail log.

    Loads ``args.cfg_file`` (if given) into ``cfg``, validates and logs it.
    It then runs the train, val and test dataset checks in that order,
    calling ``init_global_variable()`` before each one. That call presumably
    resets state left by the previous check; confirm against its definition.
    Next it runs ``inf_resize_value_check()`` and finally prints a hint
    about ``detail.log``.

    Args:
        args: parsed command-line namespace; only ``cfg_file`` is read.
    """
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    cfg.check_and_infer()
    logger.info(pprint.pformat(cfg))

    split_checks = (check_train_dataset, check_val_dataset, check_test_dataset)
    for run_check in split_checks:
        init_global_variable()
        run_check()

    inf_resize_value_check()

    print("\nDetailed error information can be viewed in detail.log file.")
예제 #8
0
파일: vis.py 프로젝트: stevephone/PaddleSeg
                # BGR->RGB
                img = cv2.imread(
                    os.path.join(cfg.DATASET.DATA_DIR,
                                 img_names[i]))[..., ::-1]
                log_writer.add_image("Images/{}".format(img_names[i]),
                                     img,
                                     epoch,
                                     dataformats='HWC')
                #add ground truth (label) images
                if grt is not None:
                    log_writer.add_image("Label/{}".format(img_names[i]),
                                         grt[..., ::-1],
                                         epoch,
                                         dataformats='HWC')

        # If in local_test mode, only visualize 5 images just for testing
        # procedure
        if local_test and img_cnt >= 5:
            break


if __name__ == '__main__':
    # Script entry point: build the global config from the command line
    # (config file first, then non-empty `opts` overrides), validate and
    # print it, then call `visualize` with every parsed argument as a kwarg.
    args = parse_args()
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    if args.opts:  # None and [] both mean "no overrides"
        cfg.update_from_list(args.opts)
    cfg.check_and_infer()
    print(pprint.pformat(cfg))
    visualize(cfg, **vars(args))