Example #1
import numpy as np


def main():
    # parse_args() is defined elsewhere in the script; it supplies the model name.
    args = parse_args()

    # Pretrained Gluon (MXNet) variant of the model.
    from gluon.utils import prepare_model as prepare_model_gl
    prepare_model_gl(model_name=args.model,
                     use_pretrained=True,
                     pretrained_model_file_path="",
                     dtype=np.float32)

    # Pretrained PyTorch variant (CPU only).
    from pytorch.utils import prepare_model as prepare_model_pt
    prepare_model_pt(model_name=args.model,
                     use_pretrained=True,
                     pretrained_model_file_path="",
                     use_cuda=False)

    # Pretrained Chainer variant.
    from chainer_.utils import prepare_model as prepare_model_ch
    prepare_model_ch(model_name=args.model,
                     use_pretrained=True,
                     pretrained_model_file_path="")

    # Pretrained TensorFlow 2 variant (CPU only).
    from tensorflow2.utils import prepare_model as prepare_model_tf2
    prepare_model_tf2(model_name=args.model,
                      use_pretrained=True,
                      pretrained_model_file_path="",
                      use_cuda=False)
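The example relies on a parse_args helper that is not shown here. A minimal sketch of what it might look like, assuming a single --model option read via argparse; the description and help text are purely illustrative:

import argparse


def parse_args():
    # Hypothetical sketch: only the --model option used by main() is modeled here.
    parser = argparse.ArgumentParser(
        description="Prepare a model in several deep-learning frameworks")
    parser.add_argument("--model", type=str, required=True,
                        help="model name, e.g. resnet18")
    return parser.parse_args()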
Example #2
import numpy as np


def prepare_dst_model(dst_fwk, dst_model, src_fwk, ctx, use_cuda):
    if dst_fwk == "gluon":
        # Gluon (MXNet): collect the prefixed parameters of the network.
        from gluon.utils import prepare_model as prepare_model_gl
        dst_net = prepare_model_gl(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="",
                                   dtype=np.float32,
                                   tune_layers="",
                                   ctx=ctx)
        dst_params = dst_net._collect_params_with_prefix()
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "pytorch":
        # PyTorch: use the state dict; drop the BatchNorm bookkeeping entries
        # ("num_batches_tracked") when converting from another framework.
        from pytorch.utils import prepare_model as prepare_model_pt
        dst_net = prepare_model_pt(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="",
                                   use_cuda=use_cuda,
                                   use_data_parallel=False)
        dst_params = dst_net.state_dict()
        dst_param_keys = list(dst_params.keys())
        if src_fwk != "pytorch":
            dst_param_keys = [
                key for key in dst_param_keys
                if not key.endswith("num_batches_tracked")
            ]
    elif dst_fwk == "chainer":
        # Chainer: build a name -> parameter mapping from namedparams().
        from chainer_.utils import prepare_model as prepare_model_ch
        dst_net = prepare_model_ch(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        dst_params = {i[0]: i[1] for i in dst_net.namedparams()}
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "keras":
        # Keras: key order is taken from the model's _arg_names/_aux_names
        # attributes; each weight name is mapped to its (layer, weight) pair.
        from keras_.utils import prepare_model as prepare_model_ke
        dst_net = prepare_model_ke(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        dst_param_keys = list(dst_net._arg_names) + list(dst_net._aux_names)
        dst_params = {}
        for layer in dst_net.layers:
            if layer.name:
                for weight in layer.weights:
                    if weight.name:
                        dst_params[weight.name] = (layer, weight)
    elif dst_fwk == "tensorflow":
        # TensorFlow (1.x graph mode): enumerate the global graph variables.
        import tensorflow as tf
        from tensorflow_.utils import prepare_model as prepare_model_tf
        dst_net = prepare_model_tf(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        dst_param_keys = [v.name for v in tf.global_variables()]
        dst_params = {v.name: v for v in tf.global_variables()}
    else:
        raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))

    return dst_params, dst_param_keys, dst_net
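A hypothetical usage sketch, assuming the helper is called from the same conversion script with an MXNet CPU context; the model name and the framework pair are illustrative:

import mxnet as mx

# Hypothetical call: prepare a PyTorch destination model on CPU for weights
# that will be copied over from a Gluon source model.
dst_params, dst_param_keys, dst_net = prepare_dst_model(
    dst_fwk="pytorch",
    dst_model="resnet18",
    src_fwk="gluon",
    ctx=mx.cpu(),
    use_cuda=False)
print(len(dst_param_keys))  # number of destination parameter tensors to fill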
Example #3
import mxnet as mx
import numpy as np


def prepare_src_model(src_fwk, src_model, src_params_file_path, dst_fwk, ctx,
                      use_cuda):
    ext_src_param_keys = None
    ext_src_param_keys2 = None

    if src_fwk == "gluon":
        # Gluon (MXNet): load the source weights and collect prefixed parameters.
        from gluon.utils import prepare_model as prepare_model_gl
        src_net = prepare_model_gl(
            model_name=src_model,
            use_pretrained=False,
            pretrained_model_file_path=src_params_file_path,
            dtype=np.float32,
            tune_layers="",
            ctx=ctx)
        src_params = src_net._collect_params_with_prefix()
        src_param_keys = list(src_params.keys())

        # Drop the bias keys of the feature blocks for these ResNet-V1 models.
        if src_model in ["resnet50_v1", "resnet101_v1", "resnet152_v1"]:
            src_param_keys = [
                key for key in src_param_keys
                if not (key.startswith("features.") and key.endswith(".bias"))
            ]

        # Skip the first four parameters for these ResNet-V2 models.
        if src_model in [
                "resnet18_v2", "resnet34_v2", "resnet50_v2", "resnet101_v2",
                "resnet152_v2"
        ]:
            src_param_keys = src_param_keys[4:]

        if dst_fwk == "chainer":
            # When converting to Chainer, split the BatchNorm running
            # statistics off into a separate key list.
            src_param_keys_ = src_param_keys.copy()
            src_param_keys = [
                key for key in src_param_keys_
                if (not key.endswith(".running_mean")) and (
                    not key.endswith(".running_var"))
            ]
            ext_src_param_keys = [
                key for key in src_param_keys_
                if (key.endswith(".running_mean")) or (
                    key.endswith(".running_var"))
            ]
            if src_model in ["condensenet74_c4_g4", "condensenet74_c8_g8"]:
                # CondenseNet: the ".index" buffers go into a second extra list.
                src_param_keys_ = src_param_keys.copy()
                src_param_keys = [
                    key for key in src_param_keys_
                    if (not key.endswith(".index"))
                ]
                ext_src_param_keys2 = [
                    key for key in src_param_keys_ if (key.endswith(".index"))
                ]

    elif src_fwk == "pytorch":
        # PyTorch: read the state dict; drop the BatchNorm bookkeeping entries
        # ("num_batches_tracked") when converting to another framework.
        from pytorch.utils import prepare_model as prepare_model_pt
        src_net = prepare_model_pt(
            model_name=src_model,
            use_pretrained=False,
            pretrained_model_file_path=src_params_file_path,
            use_cuda=use_cuda,
            use_data_parallel=False)
        src_params = src_net.state_dict()
        src_param_keys = list(src_params.keys())
        if dst_fwk != "pytorch":
            src_param_keys = [
                key for key in src_param_keys
                if not key.endswith("num_batches_tracked")
            ]
        if src_model in ["oth_shufflenetv2_wd2"]:
            src_param_keys = [
                key for key in src_param_keys
                if not key.startswith("network.0.")
            ]

    elif src_fwk == "mxnet":
        # MXNet symbolic checkpoint: merge argument and auxiliary parameters.
        src_sym, src_arg_params, src_aux_params = mx.model.load_checkpoint(
            prefix=src_params_file_path, epoch=0)
        src_params = {}
        src_params.update(src_arg_params)
        src_params.update(src_aux_params)
        src_param_keys = list(src_params.keys())
    elif src_fwk == "tensorflow":
        # TensorFlow: read the weights directly from an .npz dump instead of
        # rebuilding the graph (the graph-based route is kept for reference).
        # import tensorflow as tf
        # from tensorflow_.utils import prepare_model as prepare_model_tf
        # src_net = prepare_model_tf(
        #     model_name=src_model,
        #     classes=num_classes,
        #     use_pretrained=False,
        #     pretrained_model_file_path=src_params_file_path)
        # src_param_keys = [v.name for v in tf.global_variables()]
        # src_params = {v.name: v for v in tf.global_variables()}

        src_net = None
        src_params = dict(np.load(src_params_file_path))
        src_param_keys = list(src_params.keys())
    else:
        raise ValueError("Unsupported src fwk: {}".format(src_fwk))

    return src_params, src_param_keys, ext_src_param_keys, ext_src_param_keys2
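A hypothetical usage sketch, assuming a Gluon parameter file on disk and a PyTorch conversion target; the file path and model name are illustrative:

import mxnet as mx

# Hypothetical call: load Gluon source weights that will be converted to PyTorch.
src_params, src_param_keys, ext_keys, ext_keys2 = prepare_src_model(
    src_fwk="gluon",
    src_model="resnet18",
    src_params_file_path="resnet18.params",
    dst_fwk="pytorch",
    ctx=mx.cpu(),
    use_cuda=False)
print(len(src_param_keys))  # number of source parameter tensors to transfer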