def load_fname(prefix, suffix=None):
    """Resolve model file names and load the stored input ext info.

    ``prefix`` (followed by ``.<suffix>`` when a suffix is given) is expanded
    with ``utils.extend_fname(..., True)``. The last name returned is the ext
    file; its single stored value is loaded with ``sim.load_ext``. The names
    produced for ``prefix + ".nnvm.compile"`` (without ext) are then appended
    to the remaining names.

    Returns:
        ``(names, inputs_ext)``: the combined list of file names and the
        value unpacked from the ext file.
    """
    if suffix is None:
        tail = ""
    else:
        tail = "." + suffix
    load_prefix = prefix + tail

    file_names = list(utils.extend_fname(load_prefix, True))
    ext_file = file_names.pop()  # the last entry is the ext file
    inputs_ext, = sim.load_ext(ext_file)

    file_names += utils.extend_fname(prefix + ".nnvm.compile", False)
    return file_names, inputs_ext
def save(self, model_name, datadir="./data"):
    """Save the current model and its MRT metadata under ``datadir``.

    ``utils.extend_fname(<datadir>/<model_name>, True)`` supplies the symbol,
    params and ext file names. First ``old_names``, ``th_dict``, ``precs``
    and ``scales`` are written to the ext file via ``sim.save_ext``. Then the
    model itself is written with ``current_model.save``.
    """
    # pylint: disable=unbalanced-tuple-unpacking
    base_path = path.join(datadir, model_name)
    sym_file, params_file, ext_file = utils.extend_fname(base_path, True)
    meta = (self.old_names, self.th_dict, self.precs, self.scales)
    sim.save_ext(ext_file, *meta)
    self.current_model.save(sym_file, params_file)
def load(model_name, datadir="./data"):
    """Rebuild an MRT object from the files written by ``save``.

    The symbol/params pair from ``utils.extend_fname(<datadir>/<model_name>,
    True)`` is loaded into a ``Model`` and wrapped in ``MRT``. The four values
    stored in the ext file (``old_names``, ``th_dict``, ``precs``,
    ``scales``) are then restored onto it.
    """
    # pylint: disable=unbalanced-tuple-unpacking
    base_path = path.join(datadir, model_name)
    sym_file, params_file, ext_file = utils.extend_fname(base_path, True)
    restored = MRT(Model.load(sym_file, params_file))
    (restored.old_names, restored.th_dict,
     restored.precs, restored.scales) = sim.load_ext(ext_file)
    return restored
def dump(model, symbol, params):
    """Dump an MXNet symbol/params pair under ``./data/tf_<model>``.

    The symbol is written as JSON to the symbol file chosen by
    ``utils.extend_fname``. Every zero-dimensional (``shape == ()``)
    parameter is then removed from ``params``; none of them may be
    referenced by a node of ``symbol``. The remaining parameters are saved
    with ``nd.save`` to the params file.

    Args:
        model: Identifier interpolated into the output prefix.
        symbol: Graph to dump. It must provide ``tojson()`` and be accepted
            by ``sutils.topo_sort``.
        params: Mutable mapping of parameter name to array. Scalar-shaped
            entries are deleted from it in place.

    Raises:
        ValueError: If a scalar-shaped parameter is referenced by name in
            ``symbol``.
    """
    logger = logging.getLogger('model dump')
    prefix = "./data/tf_%s" % (model)
    sym_file, params_file = utils.extend_fname(prefix)
    with open(sym_file, "w") as f:
        f.write(symbol.tojson())

    # A set makes the per-parameter membership test below O(1).
    snames = {s.attr('name') for s in sutils.topo_sort(symbol)}
    # Iterate over a snapshot so entries can be deleted from ``params``.
    for k, v in list(params.items()):
        if v.shape != ():
            continue
        referenced = k in snames
        print("%40s \t%s %s" % (k, type(v), v.shape), referenced)
        # Explicit raise (not assert) so the check survives ``python -O``;
        # otherwise a parameter the graph needs would be dropped silently.
        if referenced:
            raise ValueError(
                "scalar parameter %r is referenced by the symbol and "
                "cannot be dropped" % (k,))
        del params[k]
    nd.save(params_file, params)
    logger.info("Model successfully dumped to '%s'", sym_file)
def load_fname(version, suffix=None, with_ext=False):
    """Return the AlexNet file names for ``version`` under ./data.

    The prefix is ``./data/alexnet<version>``, followed by ``.<suffix>`` when
    a suffix is given. Expansion into concrete file names is delegated to
    ``utils.extend_fname``.
    """
    if suffix is None:
        tail = ""
    else:
        tail = "." + suffix
    prefix = "./data/alexnet" + str(version) + tail
    return utils.extend_fname(prefix, with_ext=with_ext)
def load_fname(suffix=None, with_ext=False):
    """Return the TREC model file names under ./data.

    The prefix is ``./data/trec``, optionally followed by ``.<suffix>``, and
    is expanded by ``utils.extend_fname``.
    """
    parts = ["./data/trec"]
    if suffix is not None:
        parts.append("." + suffix)
    return utils.extend_fname("".join(parts), with_ext=with_ext)
def load_fname(version, suffix=None, with_ext=False):
    """Return the TF Inception file names for ``version`` under ./data.

    The prefix is ``./data/tf_inception<version>``, optionally followed by
    ``.<suffix>``, and is expanded by ``utils.extend_fname``.
    """
    stem = "./data/tf_inception" + str(version)
    if suffix is not None:
        stem += "." + suffix
    return utils.extend_fname(stem, with_ext)
def validate_model(sym_path, prm_path, ctx, num_channel=3,
                   input_size=224, batch_size=16, iter_num=10,
                   ds_name='imagenet', from_scratch=0,
                   lambd=None, dump_model=False, input_shape=None):
    """Calibrate, quantize and validate a model with MRT.

    Steps, in order:

    1. Load the float model from ``sym_path``/``prm_path``. If either file
       is missing, ``gluon_zoo.save_model(model_name)`` is called first
       (presumably to export the model to those paths; confirm).
    2. Take the first value returned by the ``ds.data_iter`` function and
       hand it to the MRT object via ``set_data``.
    3. Calibration: either run ``mrt.calibrate`` and save the threshold dict
       to the ext file derived from ``<model_dir>/<model_name>.mrt.dict``,
       or load it back from that file.
    4. Set 8-bit input and output precision. Then either quantize and save
       as ``<model_name>.mrt.quantize`` in ``model_dir``, or load a
       previously saved MRT from there.
    5. If ``dump_model``, dump the quantized model to CVM format and exit
       the process.
    6. Otherwise run ``utils.multi_validate`` on the original and quantized
       models for ``iter_num`` iterations.

    Args:
        sym_path: Path to the symbol file. Its basename without extension is
            used as the model name; its directory is the working directory.
        prm_path: Path to the params file.
        ctx: Context forwarded to ``load_model`` for both models.
        num_channel, input_size, batch_size: Build the default input shape
            ``(batch_size, num_channel, input_size, input_size)``.
            ``batch_size`` and ``input_size`` are also passed to the data
            iterator.
        iter_num: Number of validation iterations.
        ds_name: Dataset name passed to ``ds.data_iter``.
        from_scratch: Number of leading stages (calibration, quantization)
            to skip and load from disk instead:
            0 = calibrate and quantize,
            1 = load thresholds and quantize,
            2 = load both.
            Values outside 0..2 are not validated.
        lambd: Forwarded to ``mrt.calibrate``.
        dump_model: If true, dump the quantized model and one input sample
            under the hard-coded ``/data/ryt`` directory, then call
            ``sys.exit(0)``. Validation is NOT run in this mode.
        input_shape: Explicit input shape. When truthy, it overrides the
            default above.
    """
    from gluon_zoo import save_model

    # flag[0]: run calibration; flag[1]: run quantization. A stage whose
    # flag is False is loaded from disk instead (see ``from_scratch``).
    flag = [False]*from_scratch + [True]*(2-from_scratch)
    model_name, _ = path.splitext(path.basename(sym_path))
    model_dir = path.dirname(sym_path)
    input_shape = input_shape if input_shape else \
        (batch_size, num_channel, input_size, input_size)
    logger = logging.getLogger("log.validate.%s"%model_name)

    if not path.exists(sym_path) or not path.exists(prm_path):
        save_model(model_name)
    model = Model.load(sym_path, prm_path)
    model.prepare(input_shape)
    # model = init(model, input_shape)

    # Debug output (presumably the operator names used by the graph).
    print(tpass.collect_op_names(model.symbol, model.params))

    data_iter_func = ds.data_iter(ds_name, batch_size, input_size=input_size)
    # Only the first returned pair is used here; its second element
    # (presumably labels) is discarded.
    data, _ = data_iter_func()

    # prepare
    mrt = model.get_mrt()
    # mrt = MRT(model)

    # calibrate
    mrt.set_data(data)
    prefix = path.join(model_dir, model_name+'.mrt.dict')
    # Only the ext-file name is used; it holds the threshold dict.
    _, _, dump_ext = utils.extend_fname(prefix, True)
    if flag[0]:
        th_dict = mrt.calibrate(lambd=lambd)
        sim.save_ext(dump_ext, th_dict)
    else:
        (th_dict,) = sim.load_ext(dump_ext)
    mrt.set_th_dict(th_dict)
    mrt.set_input_prec(8)
    mrt.set_output_prec(8)

    # quantize, or reload a previously quantized MRT (which replaces the
    # object configured above)
    if flag[1]:
        mrt.quantize()
        mrt.save(model_name+".mrt.quantize", datadir=model_dir)
    else:
        mrt = MRT.load(model_name+".mrt.quantize", datadir=model_dir)

    # dump model
    if dump_model:
        # NOTE(review): hard-coded, machine-specific output directory.
        datadir = "/data/ryt"
        model_name = model_name + "_tfm"
        dump_shape = (1, num_channel, input_size, input_size)
        # NOTE(review): the model is dumped with ``input_shape`` (batch of
        # ``batch_size`` by default), while data.npy below holds a single
        # sample reshaped to ``dump_shape``. Confirm this is intended.
        mrt.current_model.to_cvm(
            model_name, datadir=datadir, input_shape=input_shape)
        data = data[0].reshape(dump_shape)
        data = sim.load_real_data(
            data.astype("float64"), 'data', mrt.get_inputs_ext())
        np.save(datadir+"/"+model_name+"/data.npy",
                data.astype('int8').asnumpy())
        # Terminates the whole process; the validation below never runs.
        sys.exit(0)

    # validate
    org_model = load_model(Model.load(sym_path, prm_path), ctx)
    cvm_quantize = load_model(
        mrt.current_model, ctx, inputs_qext=mrt.get_inputs_ext())
    utils.multi_validate(org_model, data_iter_func, cvm_quantize,
                         iter_num=iter_num,
                         logger=logging.getLogger('mrt.validate'))
    logger.info("test %s finished.", model_name)
def load_fname(suffix=None, with_ext=False):
    """Return the SSD-512 ResNet50-v1 VOC file names under ./data.

    The prefix is ``./data/ssd_512_resnet50_v1_voc``, optionally followed by
    ``.<suffix>``, and is expanded by ``utils.extend_fname``.
    """
    base = "./data/ssd_512_resnet50_v1_voc"
    prefix = base if suffix is None else base + "." + suffix
    return utils.extend_fname(prefix, with_ext)
def load_fname(version, suffix=None, with_ext=False):
    """Return the ``quick_raw_qd_animal10_2_cifar_resnet<version>`` file names.

    The prefix lives under ./data and is optionally followed by
    ``.<suffix>``. It is expanded into concrete file names by
    ``utils.extend_fname``.

    NOTE(review): an earlier revision pointed at
    ``./data/cifar_resnext29_<version>``. Confirm the current checkpoint is
    the intended one.
    """
    suffix = "." + suffix if suffix is not None else ""
    prefix = "./data/quick_raw_qd_animal10_2_cifar_resnet%s%s" % (version, suffix)
    return utils.extend_fname(prefix, with_ext=with_ext)
def load_fname(suffix=None, with_ext=False):
    """Return the Faster R-CNN ResNet50-v1b file names under ./data.

    The prefix is ``./data/faster_rcnn_resnet50_v1b``, optionally followed by
    ``.<suffix>``, and is expanded by ``utils.extend_fname``.
    """
    prefix = "./data/faster_rcnn_resnet50_v1b"
    if suffix is not None:
        prefix += "." + suffix
    return utils.extend_fname(prefix, with_ext)
def load_fname(version, suffix=None, with_ext=False):
    """Return the a02 ResNet-26 (alpha 0.250) file names for ``version``.

    The prefix is ``./data/a02_resnet-26_alpha-0.250<version>``, optionally
    followed by ``.<suffix>``, and is expanded by ``utils.extend_fname``.
    """
    pieces = ["./data/a02_resnet-26_alpha-0.250", str(version)]
    if suffix is not None:
        pieces.append("." + suffix)
    return utils.extend_fname("".join(pieces), with_ext)