def main():
    """Call every framework's ``prepare_model`` for the model chosen on the CLI.

    The model name is read from ``parse_args().model``. The Gluon, PyTorch,
    Chainer and TensorFlow 2 helpers are invoked in that order, each with
    ``use_pretrained=True`` and an empty pretrained-file path. The returned
    networks are discarded.
    NOTE(review): presumably this exists to trigger loading/caching of the
    pretrained weights in each framework - confirm.
    """
    args = parse_args()

    # Keyword arguments that are identical for all four frameworks.
    shared_kwargs = {
        "use_pretrained": True,
        "pretrained_model_file_path": "",
    }

    # Imports stay local and interleaved with the calls so that a framework
    # is only imported right before it is used.
    from gluon.utils import prepare_model as build_gluon_model
    build_gluon_model(model_name=args.model, dtype=np.float32, **shared_kwargs)

    from pytorch.utils import prepare_model as build_pytorch_model
    build_pytorch_model(model_name=args.model, use_cuda=False, **shared_kwargs)

    from chainer_.utils import prepare_model as build_chainer_model
    build_chainer_model(model_name=args.model, **shared_kwargs)

    from tensorflow2.utils import prepare_model as build_tf2_model
    build_tf2_model(model_name=args.model, use_cuda=False, **shared_kwargs)
def prepare_dst_model(dst_fwk, dst_model, src_fwk, ctx, use_cuda):
    """Create the destination network and collect its parameters.

    The network is built through the destination framework's
    ``prepare_model`` helper with ``use_pretrained=False`` and an empty
    pretrained-file path.

    Parameters
    ----------
    dst_fwk : str
        Destination framework: ``"gluon"``, ``"pytorch"``, ``"chainer"``,
        ``"keras"`` or ``"tensorflow"``.
    dst_model : str
        Model name forwarded to ``prepare_model``.
    src_fwk : str
        Source framework. Only consulted for a PyTorch destination, where it
        decides whether ``num_batches_tracked`` keys are kept.
    ctx
        Forwarded unchanged to the Gluon ``prepare_model``; unused otherwise.
    use_cuda : bool
        Forwarded unchanged to the PyTorch ``prepare_model``; unused otherwise.

    Returns
    -------
    tuple
        ``(dst_params, dst_param_keys, dst_net)``: a name -> parameter
        mapping, the list of parameter names to be filled, and the network.
        For Keras each mapping value is a ``(layer, weight)`` pair.

    Raises
    ------
    ValueError
        If ``dst_fwk`` is not one of the supported frameworks.
    """
    if dst_fwk == "gluon":
        from gluon.utils import prepare_model as prepare_model_gl
        dst_net = prepare_model_gl(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="",
            dtype=np.float32,
            tune_layers="",
            ctx=ctx)
        dst_params = dst_net._collect_params_with_prefix()
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "pytorch":
        from pytorch.utils import prepare_model as prepare_model_pt
        dst_net = prepare_model_pt(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="",
            use_cuda=use_cuda,
            use_data_parallel=False)
        dst_params = dst_net.state_dict()
        dst_param_keys = list(dst_params.keys())
        if src_fwk != "pytorch":
            # These entries are skipped when the source is another framework
            # (presumably it has no matching tensor - confirm).
            dst_param_keys = [
                key for key in dst_param_keys
                if not key.endswith("num_batches_tracked")]
    elif dst_fwk == "chainer":
        from chainer_.utils import prepare_model as prepare_model_ch
        dst_net = prepare_model_ch(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        dst_params = dict(dst_net.namedparams())
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "keras":
        from keras_.utils import prepare_model as prepare_model_ke
        dst_net = prepare_model_ke(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        # NOTE(review): `_arg_names` / `_aux_names` are private attributes;
        # confirm the object returned by the Keras helper provides them.
        dst_param_keys = list(dst_net._arg_names) + list(dst_net._aux_names)
        dst_params = {}
        for layer in dst_net.layers:
            if not layer.name:
                continue
            for weight in layer.weights:
                if weight.name:
                    # Store the pair directly. A preceding
                    # `setdefault(name, []).append(weight)` was overwritten
                    # by this assignment anyway and raised AttributeError
                    # (append on a tuple) whenever a name occurred twice.
                    dst_params[weight.name] = (layer, weight)
    elif dst_fwk == "tensorflow":
        import tensorflow as tf
        from tensorflow_.utils import prepare_model as prepare_model_tf
        dst_net = prepare_model_tf(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        # Query the variable collection once and derive both views from it.
        variables = tf.global_variables()
        dst_param_keys = [v.name for v in variables]
        dst_params = {v.name: v for v in variables}
    else:
        raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))

    return dst_params, dst_param_keys, dst_net
def prepare_src_model(src_fwk, src_model, src_params_file_path, dst_fwk, ctx, use_cuda):
    """Load the source parameters and list the keys that should be converted.

    Parameters
    ----------
    src_fwk : str
        Source framework: ``"gluon"``, ``"pytorch"``, ``"mxnet"`` or
        ``"tensorflow"``.
    src_model : str
        Model name; forwarded to ``prepare_model`` (Gluon/PyTorch) and used to
        select model-specific key filters.
    src_params_file_path : str
        Weights file for Gluon/PyTorch, checkpoint prefix for MXNet (epoch 0
        is loaded), or a file readable by ``np.load`` for TensorFlow.
    dst_fwk : str
        Destination framework; influences which keys are filtered out.
    ctx
        Forwarded unchanged to the Gluon ``prepare_model``; unused otherwise.
    use_cuda : bool
        Forwarded unchanged to the PyTorch ``prepare_model``; unused otherwise.

    Returns
    -------
    tuple
        ``(src_params, src_param_keys, ext_src_param_keys,
        ext_src_param_keys2)``. The two ``ext_*`` lists are only populated
        for a Gluon source with a Chainer destination and are ``None``
        otherwise.

    Raises
    ------
    ValueError
        If ``src_fwk`` is not one of the supported frameworks.
    """
    ext_src_param_keys = None
    ext_src_param_keys2 = None

    if src_fwk == "gluon":
        from gluon.utils import prepare_model as prepare_model_gl
        src_net = prepare_model_gl(
            model_name=src_model,
            use_pretrained=False,
            pretrained_model_file_path=src_params_file_path,
            dtype=np.float32,
            tune_layers="",
            ctx=ctx)
        src_params = src_net._collect_params_with_prefix()
        src_param_keys = list(src_params.keys())

        if src_model in ("resnet50_v1", "resnet101_v1", "resnet152_v1"):
            # Bias entries under "features." are not converted for these models.
            src_param_keys = [
                key for key in src_param_keys
                if not (key.startswith("features.") and key.endswith(".bias"))]

        if src_model in ("resnet18_v2", "resnet34_v2", "resnet50_v2",
                         "resnet101_v2", "resnet152_v2"):
            # The first four keys are skipped for the v2 ResNets.
            src_param_keys = src_param_keys[4:]

        if dst_fwk == "chainer":
            # Running statistics are split off into their own list.
            running_stat_suffixes = (".running_mean", ".running_var")
            all_keys = src_param_keys
            src_param_keys = [
                key for key in all_keys
                if not key.endswith(running_stat_suffixes)]
            ext_src_param_keys = [
                key for key in all_keys
                if key.endswith(running_stat_suffixes)]

            # NOTE(review): confirm this split is meant to apply only when
            # the destination is Chainer.
            if src_model in ("condensenet74_c4_g4", "condensenet74_c8_g8"):
                remaining_keys = src_param_keys
                src_param_keys = [
                    key for key in remaining_keys
                    if not key.endswith(".index")]
                ext_src_param_keys2 = [
                    key for key in remaining_keys
                    if key.endswith(".index")]

    elif src_fwk == "pytorch":
        from pytorch.utils import prepare_model as prepare_model_pt
        src_net = prepare_model_pt(
            model_name=src_model,
            use_pretrained=False,
            pretrained_model_file_path=src_params_file_path,
            use_cuda=use_cuda,
            use_data_parallel=False)
        src_params = src_net.state_dict()
        src_param_keys = list(src_params.keys())

        if dst_fwk != "pytorch":
            # These entries are skipped when the target is another framework
            # (presumably it has no matching tensor - confirm).
            src_param_keys = [
                key for key in src_param_keys
                if not key.endswith("num_batches_tracked")]

        if src_model in ("oth_shufflenetv2_wd2",):
            src_param_keys = [
                key for key in src_param_keys
                if not key.startswith("network.0.")]

    elif src_fwk == "mxnet":
        # The symbol returned by load_checkpoint is not needed here.
        _, src_arg_params, src_aux_params = mx.model.load_checkpoint(
            prefix=src_params_file_path,
            epoch=0)
        src_params = {}
        src_params.update(src_arg_params)
        src_params.update(src_aux_params)
        src_param_keys = list(src_params.keys())

    elif src_fwk == "tensorflow":
        # TensorFlow weights are read from a file via np.load; no network is
        # constructed for this source.
        archive = np.load(src_params_file_path)
        try:
            src_params = dict(archive)
        finally:
            # An .npz archive keeps its file handle open until closed;
            # a plain ndarray result has no close() method.
            if hasattr(archive, "close"):
                archive.close()
        src_param_keys = list(src_params.keys())

    else:
        raise ValueError("Unsupported src fwk: {}".format(src_fwk))

    return src_params, src_param_keys, ext_src_param_keys, ext_src_param_keys2