def main():
    """
    Build the pretrained model named on the command line in several frameworks.

    The model name is read from ``args.model``. For Gluon, PyTorch, Chainer and
    TensorFlow 2 (in that order) the framework's ``prepare_model`` is imported
    lazily and called with ``use_pretrained=True`` and an empty weights path.
    The returned networks are discarded.

    NOTE(review): presumably the point is the side effect of fetching/caching
    the pretrained weights -- confirm against each ``prepare_model``.
    """
    args = parse_args()

    # Keyword arguments that are identical for every framework.
    shared_kwargs = {
        "model_name": args.model,
        "use_pretrained": True,
        "pretrained_model_file_path": "",
    }

    # Each framework is imported right before use, so a missing backend only
    # fails once the preceding ones have been processed.
    from gluon.utils import prepare_model as prepare_model_gl
    prepare_model_gl(dtype=np.float32, **shared_kwargs)

    from pytorch.utils import prepare_model as prepare_model_pt
    prepare_model_pt(use_cuda=False, **shared_kwargs)

    from chainer_.utils import prepare_model as prepare_model_ch
    prepare_model_ch(**shared_kwargs)

    from tensorflow2.utils import prepare_model as prepare_model_tf2
    prepare_model_tf2(use_cuda=False, **shared_kwargs)
def prepare_dst_model(dst_fwk, dst_model, src_fwk, ctx, use_cuda):
    """
    Create a non-pretrained destination network and collect its parameters.

    Parameters
    ----------
    dst_fwk : str
        Destination framework. One of "gluon", "pytorch", "chainer", "keras"
        or "tensorflow".
    dst_model : str
        Model name forwarded as `model_name` to the framework's `prepare_model`.
    src_fwk : str
        Source framework name. Only consulted when `dst_fwk` is "pytorch":
        unless the source is also "pytorch", keys ending in
        "num_batches_tracked" are removed from the returned key list.
    ctx : object
        Forwarded as `ctx` to the Gluon `prepare_model`; ignored by the other
        branches.
    use_cuda : bool
        Forwarded as `use_cuda` to the PyTorch `prepare_model`; ignored by the
        other branches.

    Returns
    -------
    dst_params : dict
        Parameter name -> framework-specific parameter object. For Keras the
        value is a `(layer, weight)` tuple.
    dst_param_keys : list
        Parameter names of the destination network.
    dst_net : object
        The network returned by the framework's `prepare_model`.

    Raises
    ------
    ValueError
        If `dst_fwk` is not one of the supported frameworks.
    """
    if dst_fwk == "gluon":
        from gluon.utils import prepare_model as prepare_model_gl
        dst_net = prepare_model_gl(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="",
            dtype=np.float32,
            tune_layers="",
            ctx=ctx)
        dst_params = dst_net._collect_params_with_prefix()
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "pytorch":
        from pytorch.utils import prepare_model as prepare_model_pt
        dst_net = prepare_model_pt(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="",
            use_cuda=use_cuda,
            use_data_parallel=False)
        dst_params = dst_net.state_dict()
        dst_param_keys = list(dst_params.keys())
        if src_fwk != "pytorch":
            # Drop the "num_batches_tracked" entries from the key list only;
            # they remain present in dst_params.
            dst_param_keys = [
                key for key in dst_param_keys
                if not key.endswith("num_batches_tracked")
            ]
    elif dst_fwk == "chainer":
        from chainer_.utils import prepare_model as prepare_model_ch
        dst_net = prepare_model_ch(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        dst_params = {name: param for name, param in dst_net.namedparams()}
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "keras":
        from keras_.utils import prepare_model as prepare_model_ke
        dst_net = prepare_model_ke(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        # NOTE(review): the key list comes from private attributes of the
        # network, not from dst_params -- confirm both use the same names.
        dst_param_keys = list(dst_net._arg_names) + list(dst_net._aux_names)
        dst_params = {}
        for layer in dst_net.layers:
            if not layer.name:
                continue
            for weight in layer.weights:
                if weight.name:
                    # Map name -> (layer, weight). A name that shows up again
                    # (e.g. a weight shared by two layers) overwrites the
                    # earlier entry. The former setdefault(...).append(...)
                    # call was dead code and raised AttributeError on such a
                    # repeat, because the stored value is a tuple.
                    dst_params[weight.name] = (layer, weight)
    elif dst_fwk == "tensorflow":
        import tensorflow as tf
        from tensorflow_.utils import prepare_model as prepare_model_tf
        dst_net = prepare_model_tf(
            model_name=dst_model,
            use_pretrained=False,
            pretrained_model_file_path="")
        # Query the global variable collection once and derive both results
        # from the same snapshot.
        variables = tf.global_variables()
        dst_param_keys = [v.name for v in variables]
        dst_params = {v.name: v for v in variables}
    else:
        raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))

    return dst_params, dst_param_keys, dst_net