Ejemplo n.º 1
0
def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '',
                    pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
    """Load the symbol graph nodes and the parameters of an MXNet model.

    Two argument combinations are accepted:

    * neither ``nd_prefix_name`` nor ``pretrained_model_name`` given: the
      parameters are loaded from ``input_model_name`` and the model name and
      iteration number are parsed from that file name;
    * ``nd_prefix_name``, ``pretrained_model_name`` and ``input_symbol`` all
      given: the parameters are assembled by ``build_params_file``, the
      iteration number is parsed from ``pretrained_model_name`` and the model
      name is ``input_symbol`` with its last ``-``-separated part removed.

    Any other combination is rejected.

    :param input_model_name: path of the model parameters file (first mode).
    :param input_symbol: symbol file argument, forwarded to ``load_symbol_nodes``.
    :param input_names: comma-separated input names; when non-empty they are
        passed to the parameter loader as ``data_names``.
    :param nd_prefix_name: prefix of the ``.nd`` parameter files (second mode).
    :param pretrained_model_name: path of the pretrained parameters file (second mode).
    :param legacy_mxnet_model: forwarded to ``load_symbol_nodes``.
    :return: ``(model_nodes, model_params, model_name, iteration_number)``.
    :raises Error: if ``input_model_name`` cannot be parsed, or if the
        nd/pretrained arguments are only partially provided.
    """
    if not nd_prefix_name and not pretrained_model_name:
        # model name always has extension 'param'
        try:
            model_name, iteration_number = parse_input_model(input_model_name)
        except ValueError as err:
            # Chain the parse failure so its traceback is kept as __cause__.
            raise Error(
                'Input model name {} is not in an expected format, cannot extract iteration number. ' +
                refer_to_faq_msg(48),
                input_model_name) from err

        if input_names:
            model_params = load_params(input_model_name, data_names=input_names.split(','))
        else:
            model_params = load_params(input_model_name)

    elif nd_prefix_name and pretrained_model_name and input_symbol:
        # Only the iteration number is taken from the pretrained file name;
        # the model name comes from the symbol file name instead.
        _, iteration_number = parse_input_model(pretrained_model_name)
        model_name = '-'.join(input_symbol.split('-')[:-1])
        model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
    else:
        raise Error(
            "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
            refer_to_faq_msg(81))

    model_nodes = load_symbol_nodes(model_name, input_symbol, legacy_mxnet_model)

    return model_nodes, model_params, model_name, iteration_number
Ejemplo n.º 2
0
def build_params_file(nd_prefix_name: str = '',
                      pretrained_model: str = '',
                      input_names: str = ''):
    """Build model parameters from ``.nd`` files combined with a pretrained checkpoint.

    ``pretrained_model`` is parsed as ``<prefix>-<iteration>.<ext>``. ``<prefix>``
    keeps its directory part, and the integer ``<iteration>`` is the part of the
    file name after its last ``-``. The files ``<nd_prefix_name>_args.nd`` and
    ``<nd_prefix_name>_auxs.nd`` are read from the directory of ``pretrained_model``.

    :param nd_prefix_name: prefix of the ``.nd`` files. When empty, no ``.nd``
        files are read and any parameter set that ends up missing is stored as
        an empty dict.
    :param pretrained_model: path of the pretrained parameters file.
    :param input_names: comma-separated input names. They are passed to
        ``load_params`` as ``data_names`` and, unsplit, to ``add_pretrained_model``.
    :return: the object returned by ``load_params``, with its argument/auxiliary
        parameters and name lists overwritten.
    :raises ValueError: if the iteration number is not an integer.
    """
    path_wo_ext = os.path.splitext(pretrained_model)[0]
    # Everything before the last '-' of the full path (directory included).
    pretrained_model_name = path_wo_ext.rpartition('-')[0]
    # The iteration number is the last '-'-separated part of the *file name*.
    iteration_number = int(os.path.basename(path_wo_ext).rpartition('-')[2])
    files_dir = os.path.dirname(pretrained_model)

    if input_names:
        model_params = load_params(pretrained_model,
                                   data_names=input_names.split(','))
    else:
        model_params = load_params(pretrained_model)

    pretrained_params = mx.nd.load(pretrained_model) if pretrained_model_name else None

    nd_args = nd_auxs = None
    if nd_prefix_name:
        nd_args = mx.nd.load(os.path.join(files_dir, '%s_args.nd' % nd_prefix_name))
        nd_auxs = mx.nd.load(os.path.join(files_dir, '%s_auxs.nd' % nd_prefix_name))
    nd_args = add_pretrained_model(pretrained_params, nd_args,
                                   pretrained_model_name, iteration_number,
                                   input_names)

    # Without .nd files these can be None; calling .keys() on None would raise
    # AttributeError, so fall back to empty parameter sets.
    # NOTE(review): presumably mx.nd.load returns a name -> NDArray dict here; confirm.
    if nd_args is None:
        nd_args = {}
    if nd_auxs is None:
        nd_auxs = {}

    model_params._arg_params = nd_args
    model_params._aux_params = nd_auxs
    model_params._param_names = list(nd_args.keys())
    model_params._aux_names = list(nd_auxs.keys())
    return model_params
Ejemplo n.º 3
0
 def test_load_symbol_nodes_from_args_nd(self, mock_nd_load):
     """Keys loaded from ``args_model.nd`` are exposed as argument parameters.

     ``mock_nd_load`` presumably patches ``mx.nd.load`` (its decorator is not
     visible here); it returns two unprefixed weight arrays.
     """
     mock_nd_load.return_value = {'conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                  'conv1_weight': mx.nd.array([2, 3], dtype='float32')}
     model_params = load_params("args_model.nd", data_names=('data1', 'data2'))
     self.assertIn('conv0_weight', model_params._param_names)
     self.assertIn('conv1_weight', model_params._param_names)
     self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
     self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())
Ejemplo n.º 4
0
 def test_load_symbol_nodes_from_auxs_nd(self, mock_nd_load):
     """Keys loaded from ``auxs_model.nd`` are exposed as auxiliary parameters.

     ``mock_nd_load`` presumably patches ``mx.nd.load`` (its decorator is not
     visible here); it returns a single unprefixed array.
     """
     mock_nd_load.return_value = {
         'bn_data_mean': mx.nd.array([5, 6], dtype='float32')
     }
     model_params = load_params("auxs_model.nd")
     self.assertIn('bn_data_mean', model_params._aux_names)
     self.assertEqual(
         [5., 6.],
         model_params._aux_params['bn_data_mean'].asnumpy().tolist())
Ejemplo n.º 5
0
 def test_load_symbol_nodes_from_params(self, mock_nd_load):
     """A ``.params`` file's ``arg:``/``aux:`` keys are split and the prefixes stripped.

     ``mock_nd_load`` presumably patches ``mx.nd.load`` (its decorator is not
     visible here); it returns prefixed argument and auxiliary arrays.
     """
     mock_nd_load.return_value = {'arg:conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                  'arg:conv1_weight': mx.nd.array([2, 3], dtype='float32'),
                                  'aux:bn_data_mean': mx.nd.array([5, 6], dtype='float32')}
     model_params = load_params("model.params")
     self.assertIn('conv0_weight', model_params._param_names)
     self.assertIn('conv1_weight', model_params._param_names)
     self.assertIn('bn_data_mean', model_params._aux_names)
     self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
     self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())
     self.assertEqual([5., 6.], model_params._aux_params['bn_data_mean'].asnumpy().tolist())