Exemplo n.º 1
0
def main():
    """
    Prepare the model selected on the command line, with pretrained weights,
    in Gluon, PyTorch, Chainer and TensorFlow 2, one framework after another.

    The returned networks are discarded, so the calls are made only for their
    side effects (presumably fetching/caching each framework's pretrained
    weights -- confirm against the `prepare_model` helpers).
    """
    cmd_args = parse_args()

    # Keyword arguments common to every framework's `prepare_model`.
    pretrained_kwargs = dict(
        model_name=cmd_args.model,
        use_pretrained=True,
        pretrained_model_file_path="",
    )

    # Each framework package is imported only right before its own call.
    from gluon.utils import prepare_model as prepare_model_gl
    prepare_model_gl(dtype=np.float32, **pretrained_kwargs)

    from pytorch.utils import prepare_model as prepare_model_pt
    prepare_model_pt(use_cuda=False, **pretrained_kwargs)

    from chainer_.utils import prepare_model as prepare_model_ch
    prepare_model_ch(**pretrained_kwargs)

    from tensorflow2.utils import prepare_model as prepare_model_tf2
    prepare_model_tf2(use_cuda=False, **pretrained_kwargs)
Exemplo n.º 2
0
def prepare_dst_model(dst_fwk, dst_model, src_fwk, ctx, use_cuda):
    """
    Build a non-pretrained destination model and collect its parameters.

    Parameters:
    ----------
    dst_fwk : str
        Destination framework: "gluon", "pytorch", "chainer", "keras" or "tensorflow".
    dst_model : str
        Model name passed to the framework's `prepare_model`.
    src_fwk : str
        Source framework. For a PyTorch destination, any value other than
        "pytorch" causes `num_batches_tracked` entries to be dropped from the key list.
    ctx : object
        Gluon context; only used when `dst_fwk == "gluon"`.
    use_cuda : bool
        Only used when `dst_fwk == "pytorch"`.

    Returns:
    -------
    dst_params : dict
        Parameter name -> framework-specific parameter object
        (for Keras the value is a `(layer, weight)` pair).
    dst_param_keys : list of str
        Destination parameter names.
    dst_net : object
        The constructed destination network.

    Raises:
    ------
    ValueError
        If `dst_fwk` is not one of the supported frameworks.
    """
    if dst_fwk == "gluon":
        from gluon.utils import prepare_model as prepare_model_gl
        dst_net = prepare_model_gl(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="",
                                   dtype=np.float32,
                                   tune_layers="",
                                   ctx=ctx)
        dst_params = dst_net._collect_params_with_prefix()
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "pytorch":
        from pytorch.utils import prepare_model as prepare_model_pt
        dst_net = prepare_model_pt(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="",
                                   use_cuda=use_cuda,
                                   use_data_parallel=False)
        dst_params = dst_net.state_dict()
        dst_param_keys = list(dst_params.keys())
        if src_fwk != "pytorch":
            # Presumably these BatchNorm counters have no counterpart in the
            # other frameworks, so they are excluded from the keys to fill.
            dst_param_keys = [
                key for key in dst_param_keys
                if not key.endswith("num_batches_tracked")
            ]
    elif dst_fwk == "chainer":
        from chainer_.utils import prepare_model as prepare_model_ch
        dst_net = prepare_model_ch(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        # `namedparams()` yields (name, param) pairs.
        dst_params = dict(dst_net.namedparams())
        dst_param_keys = list(dst_params.keys())
    elif dst_fwk == "keras":
        from keras_.utils import prepare_model as prepare_model_ke
        dst_net = prepare_model_ke(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        dst_param_keys = list(dst_net._arg_names) + list(dst_net._aux_names)
        dst_params = {}
        for layer in dst_net.layers:
            if layer.name:
                for weight in layer.weights:
                    if weight.name:
                        # One (layer, weight) entry per weight name; if a name
                        # repeats, the last occurrence wins. (A former
                        # `setdefault(...).append(...)` here was dead code and
                        # raised AttributeError on a repeated name.)
                        dst_params[weight.name] = (layer, weight)
    elif dst_fwk == "tensorflow":
        import tensorflow as tf
        from tensorflow_.utils import prepare_model as prepare_model_tf
        dst_net = prepare_model_tf(model_name=dst_model,
                                   use_pretrained=False,
                                   pretrained_model_file_path="")
        # Query the global variable collection once and reuse it.
        tf_vars = tf.global_variables()
        dst_param_keys = [v.name for v in tf_vars]
        dst_params = {v.name: v for v in tf_vars}
    else:
        raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))

    return dst_params, dst_param_keys, dst_net