def _prepare_module(task_id, symbol, ctx_config, data_names, label_names, resume_config):
    """Build an MXNet ``Module`` for a task, optionally pre-loading parameters
    from a saved checkpoint.

    Args:
        task_id: identifier embedded in the per-task logger name.
        symbol: network symbol the module wraps.
        ctx_config: context configuration forwarded to ``Executor._prepare_ctx``.
        data_names: names of the module's input data variables.
        label_names: names of the module's label variables.
        resume_config: dict with key ``'is_resume'`` and, on the checkpoint
            branch, ``'ckp'`` = ``{'prefix': ..., 'epoch': ...}`` locating the
            ``.params`` file under ``params_root_path``.

    Returns:
        A ``Module``; on the checkpoint branch its arg/aux parameters are
        pre-assigned and ``params_initialized`` is set to True.
    """
    logger = get_logger('mxnet_logger[tid=%s]' % task_id,
                        log_to_console=False, log_to_file=True)

    # NOTE(review): is_resume != '0' takes the *fresh-module* branch, which
    # reads backwards for a flag named "is_resume" — confirm against callers.
    if resume_config['is_resume'] != '0':
        return Module(symbol=symbol,
                      context=Executor._prepare_ctx(ctx_config),
                      data_names=data_names,
                      label_names=label_names,
                      logger=logger)

    ckp = resume_config['ckp']
    params_path = osp.join(params_root_path,
                           '%s-%04d.params' % (ckp['prefix'], ckp['epoch']))

    # Copied from MXNet
    # Licensed to the Apache Software Foundation (ASF) under one
    # or more contributor license agreements.  See the NOTICE file
    # distributed with this work for additional information
    # regarding copyright ownership.  The ASF licenses this file
    # to you under the Apache License, Version 2.0 (the
    # "License"); you may not use this file except in compliance
    # with the License.  You may obtain a copy of the License at
    #
    #   http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing,
    # software distributed under the License is distributed on an
    # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    # KIND, either express or implied.  See the License for the
    # specific language governing permissions and limitations
    # under the License.
    save_dict = nd.load(params_path)
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # Saved keys are '<kind>:<param-name>'; a key may be arg or aux,
        # never both, so elif is safe.
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v

    # NOTE(review): data_names/label_names are not passed on this branch in
    # the original code — preserved; verify this is intentional.
    mod = Module(symbol=symbol,
                 context=Executor._prepare_ctx(ctx_config),
                 logger=logger)
    mod._arg_params = arg_params
    mod._aux_params = aux_params
    mod.params_initialized = True
    # TODO: There is a parameter named load_optimizer_states in Module.load
    return mod
def load_check_point(sym_json_path, params_path, ctx_config_tuple, task_id):
    """Restore a ``Module`` from a serialized symbol and a ``.params`` file.

    Args:
        sym_json_path: path to a symbol JSON file, or an already-loaded
            ``sym.Symbol`` instance, which is used as-is.
        params_path: path to the saved parameter file read via ``nd.load``.
        ctx_config_tuple: context configuration; materialized into a list and
            handed to ``generate_ctx``.
        task_id: identifier embedded in the per-task logger name.

    Returns:
        A ``Module`` whose arg/aux parameters are pre-assigned and whose
        ``params_initialized`` flag is set to True.
    """
    ctx_config = list(ctx_config_tuple)

    # Copied from MXNet
    # Licensed to the Apache Software Foundation (ASF) under one
    # or more contributor license agreements.  See the NOTICE file
    # distributed with this work for additional information
    # regarding copyright ownership.  The ASF licenses this file
    # to you under the Apache License, Version 2.0 (the
    # "License"); you may not use this file except in compliance
    # with the License.  You may obtain a copy of the License at
    #
    #   http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing,
    # software distributed under the License is distributed on an
    # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    # KIND, either express or implied.  See the License for the
    # specific language governing permissions and limitations
    # under the License.
    if isinstance(sym_json_path, sym.Symbol):
        # Caller already supplied a symbol object; no deserialization needed.
        symbol = sym_json_path
    else:
        symbol = sym.load(sym_json_path)

    loaded = nd.load(params_path)
    arg_params = {}
    aux_params = {}
    # Saved keys look like '<kind>:<param-name>'; route each entry into the
    # matching bucket and silently drop any other kind, as before.
    buckets = {'arg': arg_params, 'aux': aux_params}
    for key, value in loaded.items():
        kind, pname = key.split(':', 1)
        if kind in buckets:
            buckets[kind][pname] = value

    mod = Module(symbol=symbol,
                 context=generate_ctx(ctx_config),
                 logger=get_logger('mxnet_logger[tid=%s]' % task_id,
                                   log_to_console=False, log_to_file=True))
    mod._arg_params = arg_params
    mod._aux_params = aux_params
    mod.params_initialized = True
    # TODO: There is a parameter named load_optimizer_states in Module.load
    return mod