示例#1
0
def create_net(configs):
    '''
    Build and bind an MXNet inference module from a net configuration.

    ``configs`` is parsed by ``net.NetConfig``. The parsed config supplies the
    params file (``file_model``), the symbol file (``file_symbol``), the label
    file (``file_synset``), the target device (``use_device``), and the
    image/normalization/batch settings.

    ``mx.model.load_checkpoint(prefix, epoch)`` locates files by the naming
    convention ``<prefix>-symbol.json`` and ``<prefix>-<epoch:04d>.params``.
    To satisfy it, the symbol and params files are renamed in place to
    ``<file_symbol>-symbol.json`` and ``<file_symbol>-0000.params`` before
    loading with epoch 0.

    NOTE: the rename moves the caller's files on disk. Calling this again with
    the same config will fail, because the original paths no longer exist.

    Returns a ``(result, code, message)`` triple:
      * on success: ``({"net": {...}}, 0, '')``. The inner dict holds the
        loaded labels, image width/height, mean/std values, batch size and the
        bound ``mx.mod.Module`` under ``mod``.
      * on any failure: ``({}, 599, str(error))``. The traceback is logged.
    '''
    conf = net.NetConfig()
    conf.parse(configs)

    logger.info("[Python net_init] load configs: %s",
                configs,
                extra={"reqid": ""})

    try:
        params_file, sym_file, label_file = (conf.file_model, conf.file_symbol,
                                             conf.file_synset)
        # Make sym_file usable as a checkpoint prefix: load_checkpoint(sym_file, 0)
        # reads <sym_file>-symbol.json and <sym_file>-0000.params.
        os.rename(sym_file, sym_file + '-symbol.json')
        os.rename(params_file, sym_file + '-0000.params')

        ctx = mx.gpu() if conf.use_device == 'GPU' else mx.cpu()

        sym, arg_params, aux_params = mx.model.load_checkpoint(sym_file, 0)
        mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)

        # Logged before defaults are applied, so unset sizes show as 0/None.
        logger.info(
            "config of width:{}, height:{}, value_mean:{}, value_std:{},batch_size:{}"
            .format(conf.image_width, conf.image_height, conf.value_mean,
                    conf.value_std, conf.batch_size),
            extra={"reqid": ""})

        # An unset size (0 or None) falls back to a square 224-pixel input.
        # A missing height mirrors the (possibly defaulted) width.
        default_width = 224
        if conf.image_width is None or conf.image_width == 0:
            conf.image_width = default_width
        if conf.image_height is None or conf.image_height == 0:
            conf.image_height = conf.image_width

        # MXNet's default data layout is NCHW: (batch, channels, height, width).
        mod.bind(for_training=False,
                 data_shapes=[('data', (conf.batch_size, 3, conf.image_height,
                                        conf.image_width))],
                 label_shapes=mod._label_shapes)
        mod.set_params(arg_params, aux_params, allow_missing=True)

        # Load labels inside the try. A bad label file is then reported
        # through the same (599, message) path as every other init failure,
        # instead of escaping to the caller as an exception.
        labels = net.load_labels(label_file)

    except Exception as err:
        logger.error("[Python net_init] failed: %s",
                     traceback.format_exc(),
                     extra={"reqid": ""})
        return {}, 599, str(err)

    return {
        "net":
        dict(error='',
             labels=labels,
             image_width=conf.image_width,
             image_height=conf.image_height,
             mean_value=conf.value_mean,
             std_value=conf.value_std,
             batch_size=conf.batch_size,
             mod=mod)
    }, 0, ''
示例#2
0
def _load_cls(label_file):
    """Return a tuple with the last field of each entry from ``net.load_labels``.

    Presumably each label entry is a sequence whose final element is the
    class name. TODO: confirm against ``net.load_labels``.
    """
    entries = net.load_labels(label_file)
    return tuple([entry[-1] for entry in entries])