Example #1
0
# coding=utf-8

from easydict import EasyDict

from mxnetseg.data import get_dataset_info
from mxnetseg.tools import weight_path, record_path

# Single module-level configuration object; `config` is an alias bound to the
# very same EasyDict instance, so both names see every attribute set below.
C = EasyDict()
config = C

# Model identity. The weight and record directories are both derived from
# the model name, so `model_name` must be assigned before them.
C.model_name = 'fcn'
C.model_dir = weight_path(C.model_name)
C.record_dir = record_path(C.model_name)

# Dataset selection. Names listed by the original author as options:
# COCO, VOC2012, VOCAug, SBD, PContext,
# Cityscapes, CamVid, CamVidFull, Stanford, GATECH, KITTIZhang, KITTIXu, KITTIRos
# NYU, SiftFlow, SUNRGBD, ADE20K
C.data_name = 'Cityscapes'
# NOTE(review): presumably the training crop size and the base (pre-crop)
# image size, in pixels -- confirm against the data pipeline that reads them.
C.crop = 768
C.base = 2048
# get_dataset_info returns a pair that is unpacked into the dataset path and
# the class count for the selected dataset.
C.data_path, C.nclass = get_dataset_info(C.data_name)

# Network options.
C.backbone = 'resnet18'
C.pretrained_base = True
C.dilate = False
# NOTE(review): 'sbn' presumably selects synchronized batch norm -- confirm
# against the code that interprets this key.
C.norm = 'sbn'
C.aux = False
# The auxiliary weight is only meaningful when `aux` is enabled; it is None
# otherwise.
C.aux_weight = .7 if C.aux else None
Example #2
0
def _validate_checkpoint(model_name: str, checkpoint: str):
    """Resolve a checkpoint file name to a path and verify that it exists.

    The file name is joined onto the directory returned by
    ``weight_path(model_name)``.

    :param model_name: model name passed to ``weight_path``.
    :param checkpoint: checkpoint file name (or relative path) to look up.
    :return: the joined path to the checkpoint file.
    :raises RuntimeError: if the joined path is not an existing regular file.
    """
    weights_dir = weight_path(model_name)
    checkpoint = os.path.join(weights_dir, checkpoint)
    if os.path.isfile(checkpoint):
        return checkpoint
    raise RuntimeError(f"No model params found at {checkpoint}")