예제 #1
0
def setup(args):
    """
    Create configs and perform basic setups for TensorMask training on BDD100K.

    The base config file and the output directory are hard-coded here;
    ``args.config_file`` is not read. Overrides in ``args.opts`` are merged
    last, so they take precedence over every value set in this function.

    Parameters
    ----------
    args
        Parsed command-line arguments. ``args.opts`` is merged into the config
        and ``args`` is forwarded to ``default_setup``.

    Returns
    -------
    cfg
        The frozen config object created by ``get_cfg()``.
    """
    cfg = get_cfg()
    add_tensormask_config(cfg)
    # NOTE(review): the directory name "BDD00K" differs from the "bdd100k"
    # spelling used elsewhere in this function -- confirm it matches disk.
    cfg.merge_from_file(
        "configs/BDD00K-InstanceSegmentation/tensormask_r101_3x_single_scale_bs8_bdd100k.yaml"
    )
    cfg.DATASETS.TRAIN = ("bdd100k_train", )
    cfg.DATASETS.TEST = ("bdd100k_test", )

    # NOTE(review): the config file name says "bs8" but the output directory
    # says "bs16" -- confirm which batch size this run actually uses.
    cfg.OUTPUT_DIR = './tensormask_r101_3x_single_scale_bs16_bdd100k'
    # Checkpoint and evaluate every 500 iterations.
    cfg.SOLVER.CHECKPOINT_PERIOD = 500
    cfg.TEST.EVAL_PERIOD = 500

    # Command-line overrides are applied last so they win.
    cfg.merge_from_list(args.opts)
    # Create the output directory only after the overrides are merged, so an
    # OUTPUT_DIR supplied through args.opts is the directory that gets created
    # (rather than a stray directory at the hard-coded path).
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
예제 #2
0
def load_tensormask_model(model_path, cfg_path, device):
    """
    Load the pretrained TensorMask model states and prepare the model for image segmentation.

    Parameters
    ----------
    model_path: str
        Path to the pretrained model states binary file.
    cfg_path: str
        Path to the model's configuration file. Located in the .configs folder.
    device: torch.device or str
        Device to load the model on. It is written into ``cfg.MODEL.DEVICE``
        before the model is built, so the model is constructed on it.

    Returns
    -------
    model: TensorMask
        Model with the loaded pretrained states, switched to evaluation mode.
    """
    # set up model config
    cfg = get_cfg()
    tensormask.add_tensormask_config(cfg)
    cfg.merge_from_file(cfg_path)
    # Honour the requested device: build_model places the model on
    # cfg.MODEL.DEVICE, so without this line `device` would be ignored and the
    # config's own device setting would be used instead.
    cfg.MODEL.DEVICE = str(device)
    model = build_model(cfg)

    # load the model weights
    DetectionCheckpointer(model).load(model_path)
    # Inference mode (disables training-only behaviour such as dropout/BN updates).
    model.eval()

    return model
예제 #3
0
def setup(args):
    """
    Create configs and perform basic setups.

    The function proceeds in this order:

    1. Load ``tensormask/configs/<args.config_file>.yaml``.
    2. Apply ``args.opts``.
    3. Apply the optional ``args.log_dir`` / ``args.data_dir`` overrides.
    4. Fall back to CPU when CUDA is unavailable.
    5. Register the train/test datasets.
    6. Create ``<OUTPUT_DIR_BASE>/<args.config_file>`` as the output directory.

    Parameters
    ----------
    args
        Parsed command-line arguments providing ``config_file``, ``opts``,
        ``log_dir`` and ``data_dir``; also forwarded to ``default_setup``.

    Returns
    -------
    cfg
        The frozen config object created by ``get_cfg()``.
    """
    cfg = get_cfg()
    add_tensormask_config(cfg)

    cfg.merge_from_file(f'tensormask/configs/{args.config_file}.yaml')
    cfg.merge_from_list(args.opts)

    if args.log_dir:
        cfg.OUTPUT_DIR_BASE = args.log_dir
    if args.data_dir:
        cfg.DATASETS.TRAIN = (args.data_dir,)

    # Only force CPU when CUDA is missing. Unconditionally reassigning the
    # device here would discard an explicit MODEL.DEVICE (e.g. "cuda:1" or
    # "cpu") given in the config file or in args.opts.
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    print('device:', cfg.MODEL.DEVICE)

    register_datasets(cfg.DATASETS.TRAIN)
    register_datasets(cfg.DATASETS.TEST)

    # Set up the logging directory. OUTPUT_DIR_BASE is only guaranteed to exist
    # when --log_dir was given (or the project config defines it); otherwise
    # fall back to OUTPUT_DIR instead of failing with an AttributeError.
    output_base = cfg.get("OUTPUT_DIR_BASE", cfg.OUTPUT_DIR)
    cfg.OUTPUT_DIR = os.path.join(output_base, args.config_file)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
예제 #4
0
def setup(args):
    """
    Build the TensorMask config from command-line arguments.

    The YAML file named by ``args.config_file`` is merged first. The overrides
    in ``args.opts`` are merged on top of it. The resulting config is frozen
    and handed, together with ``args``, to ``default_setup``.

    Returns
    -------
    config
        The frozen config object created by ``get_cfg()``.
    """
    config = get_cfg()
    add_tensormask_config(config)

    # File values first, then command-line overrides on top of them.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    config.freeze()
    default_setup(config, args)
    return config
예제 #5
0
def setup(
    args,
    *,
    config_file="/root/detectron2/projects/TensorMask/configs/tensormask_R_50_FPN_1x.yaml",
    weights="/root/detectron2/projects/TensorMask/log_80_20/model_0024999.pth",
):
    """
    Create a TensorMask config that points at a trained checkpoint.

    Unlike a training setup, this does not merge ``args.opts``, does not
    freeze the config and does not call ``default_setup``; the caller
    receives a mutable config.

    Parameters
    ----------
    args
        Accepted for call-signature compatibility; not read.
    config_file : str, optional
        YAML config to merge. Defaults to the original hard-coded
        ``tensormask_R_50_FPN_1x.yaml`` path under ``/root/detectron2``.
    weights : str, optional
        Checkpoint path assigned to ``cfg.MODEL.WEIGHTS``. Defaults to the
        original hard-coded ``log_80_20/model_0024999.pth`` path.

    Returns
    -------
    cfg
        The (unfrozen) config object created by ``get_cfg()``.
    """
    cfg = get_cfg()
    add_tensormask_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights
    return cfg
예제 #6
0
def setup(args):
    """
    Create configs and perform basic setups.

    In evaluation-only mode, a fixed checkpoint and ``SOLVER.IMS_PER_BATCH``
    are applied as *defaults* before ``args.opts`` is merged. An explicit
    ``MODEL.WEIGHTS`` (or ``SOLVER.IMS_PER_BATCH``) passed on the command
    line therefore still takes effect instead of being silently overwritten.

    Parameters
    ----------
    args
        Parsed command-line arguments providing ``config_file``, ``opts`` and
        ``eval_only``; also forwarded to ``default_setup``.

    Returns
    -------
    cfg
        The frozen config object created by ``get_cfg()``.
    """
    cfg = get_cfg()
    add_tensormask_config(cfg)
    cfg.merge_from_file(args.config_file)

    if args.eval_only:
        # Evaluation defaults; overridable through args.opts below.
        cfg.MODEL.WEIGHTS = "/root/detectron2/projects/TensorMask/log_50_50/model_0034999.pth"
        cfg.SOLVER.IMS_PER_BATCH = 6

    # Merged last so command-line overrides win over the defaults above.
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    default_setup(cfg, args)
    return cfg