def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '',
                    pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
    """Load a symbol definition together with its parameters.

    Two argument combinations are supported:

    * neither ``nd_prefix_name`` nor ``pretrained_model_name`` is given: the model name,
      iteration number and parameters all come from ``input_model_name``;
    * ``nd_prefix_name``, ``pretrained_model_name`` and ``input_symbol`` are all given:
      the iteration number comes from ``pretrained_model_name``, the model name from
      ``input_symbol``, and the parameters are assembled by ``build_params_file``.

    Any other combination is rejected.

    :param input_model_name: path of the parameters file used in the first mode.
    :param input_symbol: symbol file forwarded to ``load_symbol_nodes``; in the second mode
        its name with the last ``-``-separated segment dropped becomes the model name.
    :param input_names: comma-separated input names forwarded to the parameter loader.
    :param nd_prefix_name: prefix of the ``*_args.nd`` / ``*_auxs.nd`` files (second mode).
    :param pretrained_model_name: path of the pretrained model file (second mode).
    :param legacy_mxnet_model: forwarded unchanged to ``load_symbol_nodes``.
    :return: tuple ``(model_nodes, model_params, model_name, iteration_number)``.
    :raises Error: if ``input_model_name`` cannot be parsed in the first mode, or if the
        arguments form an unsupported combination.
    """
    if not nd_prefix_name and not pretrained_model_name:
        # model name always has extension 'param'
        try:
            model_name, iteration_number = parse_input_model(input_model_name)
        except ValueError as err:
            # Chain the parse failure so its root cause stays visible in the traceback.
            raise Error(
                'Input model name {} is not in an expected format, cannot extract iteration number. ' +
                refer_to_faq_msg(48),
                input_model_name) from err

        if input_names:
            model_params = load_params(input_model_name, data_names=input_names.split(','))
        else:
            model_params = load_params(input_model_name)

    elif nd_prefix_name and pretrained_model_name and input_symbol:
        # Only the iteration number is taken from the pretrained model file; the model
        # name is derived from the symbol file name instead.
        _, iteration_number = parse_input_model(pretrained_model_name)
        model_name = '-'.join(input_symbol.split('-')[:-1])
        model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
    else:
        raise Error(
            "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
            refer_to_faq_msg(81))

    model_nodes = load_symbol_nodes(model_name, input_symbol, legacy_mxnet_model)

    return model_nodes, model_params, model_name, iteration_number
def build_params_file(nd_prefix_name: str = '', pretrained_model: str = '', input_names: str = ''):
    """Assemble model parameters from a pretrained model file plus ``.nd`` files.

    The ``<nd_prefix_name>_args.nd`` and ``<nd_prefix_name>_auxs.nd`` files are looked up
    in the same directory as ``pretrained_model``.

    :param nd_prefix_name: prefix of the ``*_args.nd`` / ``*_auxs.nd`` files; when empty,
        neither file is loaded and the aux parameters end up empty.
    :param pretrained_model: path of the pretrained model file, expected to be named like
        ``<name>-<iteration>.<ext>``.
    :param input_names: comma-separated input names forwarded to ``load_params`` and
        ``add_pretrained_model``.
    :return: the object produced by ``load_params`` with its arg/aux params and name lists
        replaced by the merged values.
    :raises ValueError: if the iteration number cannot be parsed from ``pretrained_model``.
    """
    # Drop the extension (everything after the last '.').
    path_wo_ext = '.'.join(pretrained_model.split('.')[:-1])
    pretrained_model_name_w_iter = path_wo_ext.split(os.sep)[-1]
    # Prefix without the trailing '-<iteration>' segment; note it keeps the directory part.
    pretrained_model_name = '-'.join(path_wo_ext.split('-')[:-1])
    iteration_number = int(pretrained_model_name_w_iter.split('-')[-1])
    files_dir = os.path.dirname(pretrained_model)

    if input_names:
        model_params = load_params(pretrained_model, data_names=input_names.split(','))
    else:
        model_params = load_params(pretrained_model)

    pretrained_params = mx.nd.load(pretrained_model) if pretrained_model_name else None
    nd_args = mx.nd.load(os.path.join(files_dir, '%s_args.nd' % nd_prefix_name)) if nd_prefix_name else None
    nd_auxs = mx.nd.load(os.path.join(files_dir, '%s_auxs.nd' % nd_prefix_name)) if nd_prefix_name else None
    nd_args = add_pretrained_model(pretrained_params, nd_args, pretrained_model_name,
                                   iteration_number, input_names)

    # Without an nd prefix no aux file was loaded; use an empty mapping instead of None so
    # listing the aux names below does not fail with AttributeError.
    if nd_auxs is None:
        nd_auxs = {}

    model_params._arg_params = nd_args
    model_params._aux_params = nd_auxs
    model_params._param_names = list(nd_args.keys())
    model_params._aux_names = list(nd_auxs.keys())
    return model_params
def test_load_symbol_nodes_from_args_nd(self, mock_nd_load):
    """Entries of an ``*_args.nd`` file are exposed as arg params by ``load_params``."""
    # mock_nd_load presumably patches mx.nd.load (decorator not visible here) — confirm.
    mock_nd_load.return_value = {
        'conv0_weight': mx.nd.array([1, 2], dtype='float32'),
        'conv1_weight': mx.nd.array([2, 3], dtype='float32'),
    }
    model_params = load_params("args_model.nd", data_names=('data1', 'data2'))
    # assertIn reports the missing key and the actual names on failure.
    self.assertIn('conv0_weight', model_params._param_names)
    self.assertIn('conv1_weight', model_params._param_names)
    self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
    self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())
def test_load_symbol_nodes_from_auxs_nd(self, mock_nd_load):
    """Entries of an ``*_auxs.nd`` file are exposed as aux params by ``load_params``."""
    # mock_nd_load presumably patches mx.nd.load (decorator not visible here) — confirm.
    mock_nd_load.return_value = {
        'bn_data_mean': mx.nd.array([5, 6], dtype='float32'),
    }
    model_params = load_params("auxs_model.nd")
    self.assertIn('bn_data_mean', model_params._aux_names)
    self.assertEqual([5., 6.], model_params._aux_params['bn_data_mean'].asnumpy().tolist())
def test_load_symbol_nodes_from_params(self, mock_nd_load):
    """``arg:``/``aux:`` prefixed keys of a ``.params`` file are split into arg and aux params."""
    # mock_nd_load presumably patches mx.nd.load (decorator not visible here) — confirm.
    mock_nd_load.return_value = {
        'arg:conv0_weight': mx.nd.array([1, 2], dtype='float32'),
        'arg:conv1_weight': mx.nd.array([2, 3], dtype='float32'),
        'aux:bn_data_mean': mx.nd.array([5, 6], dtype='float32'),
    }
    model_params = load_params("model.params")
    # The 'arg:'/'aux:' prefixes are expected to be stripped from the stored names.
    self.assertIn('conv0_weight', model_params._param_names)
    self.assertIn('conv1_weight', model_params._param_names)
    self.assertIn('bn_data_mean', model_params._aux_names)
    self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
    self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())
    self.assertEqual([5., 6.], model_params._aux_params['bn_data_mean'].asnumpy().tolist())