Пример #1
0
    def _prepare_module(task_id, symbol, ctx_config, data_names, label_names,
                        resume_config):
        """Build the MXNet ``Module`` for a task, optionally restoring params.

        Args:
            task_id: Identifier embedded in the logger name
                (``'mxnet_logger[tid=<task_id>]'``).
            symbol: Network symbol handed to ``Module``.
            ctx_config: Device configuration, turned into a context by
                ``Executor._prepare_ctx``.
            data_names: Data input names handed to ``Module``.
            label_names: Label input names handed to ``Module``.
            resume_config: Mapping with an ``'is_resume'`` flag. ``'0'`` (or a
                falsy spelling such as ``0``/``False``/``'false'``/``''``)
                means start fresh. Any other value means resume. In that case
                ``resume_config['ckp']`` must provide ``'prefix'`` and
                ``'epoch'``, naming the file ``<prefix>-<epoch:04d>.params``
                under ``params_root_path``.

        Returns:
            A ``Module``. When resuming, its arg/aux params are preloaded and
            ``params_initialized`` is set, mirroring MXNet's ``Module.load``.
        """
        # Accept the string flag '0' as well as other falsy spellings, so a
        # bool/int/None flag cannot accidentally trigger a resume.
        flag = resume_config['is_resume']
        resuming = str(flag).strip().lower() not in ('0', 'false', 'none', '')

        arg_params = aux_params = None
        if resuming:
            ckp = resume_config['ckp']
            # int() accepts both ints and numeric strings coming from config.
            params_file = '%s-%04d.params' % (ckp['prefix'], int(ckp['epoch']))
            params_path = osp.join(params_root_path, params_file)
            # Copied from MXNet

            # Licensed to the Apache Software Foundation (ASF) under one
            # or more contributor license agreements.  See the NOTICE file
            # distributed with this work for additional information
            # regarding copyright ownership.  The ASF licenses this file
            # to you under the Apache License, Version 2.0 (the
            # "License"); you may not use this file except in compliance
            # with the License.  You may obtain a copy of the License at
            #
            #   http://www.apache.org/licenses/LICENSE-2.0
            #
            # Unless required by applicable law or agreed to in writing,
            # software distributed under the License is distributed on an
            # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
            # KIND, either express or implied.  See the License for the
            # specific language governing permissions and limitations
            # under the License.
            save_dict = nd.load(params_path)
            arg_params = {}
            aux_params = {}
            # Saved keys look like '<arg|aux>:<name>'; other prefixes are
            # ignored.
            for k, v in save_dict.items():
                tp, name = k.split(':', 1)
                if tp == 'arg':
                    arg_params[name] = v
                elif tp == 'aux':
                    aux_params[name] = v

        # Pass data/label names in both cases so a resumed module is wired
        # the same way as a freshly created one.
        mod = Module(symbol=symbol,
                     context=Executor._prepare_ctx(ctx_config),
                     data_names=data_names,
                     label_names=label_names,
                     logger=get_logger('mxnet_logger[tid=%s]' % task_id,
                                       log_to_console=False,
                                       log_to_file=True))
        if resuming:
            # Same attributes MXNet's Module.load sets on a restored module.
            mod._arg_params = arg_params
            mod._aux_params = aux_params
            mod.params_initialized = True
            # TODO: Module.load also supports load_optimizer_states; not
            # handled here.
        return mod
Пример #2
0
    def load_check_point(sym_json_path, params_path, ctx_config_tuple,
                         task_id):
        """Rebuild an MXNet ``Module`` from a symbol and a saved params file.

        Args:
            sym_json_path: Path to a symbol JSON file, or an already-loaded
                ``sym.Symbol``, which is used as-is.
            params_path: Parameters file read with ``nd.load``. Keys are split
                on the first ``':'``. ``'arg:*'`` and ``'aux:*'`` entries are
                kept and any other prefix is dropped.
            ctx_config_tuple: Device configuration; copied into a list and
                handed to ``generate_ctx``.
            task_id: Identifier embedded in the logger name.

        Returns:
            A ``Module`` with arg/aux params preloaded and
            ``params_initialized`` set.

        Raises:
            ValueError: if a key in the params file contains no ``':'``.
        """
        ctx_config = list(ctx_config_tuple)
        # Copied from MXNet

        # Licensed to the Apache Software Foundation (ASF) under one
        # or more contributor license agreements.  See the NOTICE file
        # distributed with this work for additional information
        # regarding copyright ownership.  The ASF licenses this file
        # to you under the Apache License, Version 2.0 (the
        # "License"); you may not use this file except in compliance
        # with the License.  You may obtain a copy of the License at
        #
        #   http://www.apache.org/licenses/LICENSE-2.0
        #
        # Unless required by applicable law or agreed to in writing,
        # software distributed under the License is distributed on an
        # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
        # KIND, either express or implied.  See the License for the
        # specific language governing permissions and limitations
        # under the License.
        if isinstance(sym_json_path, sym.Symbol):
            # The caller already passed a loaded symbol.
            symbol = sym_json_path
        else:
            symbol = sym.load(sym_json_path)

        # Route each saved array into its dict by key prefix.
        params_by_kind = {'arg': {}, 'aux': {}}
        for key, array in nd.load(params_path).items():
            kind, param_name = key.split(':', 1)
            bucket = params_by_kind.get(kind)
            if bucket is not None:
                bucket[param_name] = array

        context = generate_ctx(ctx_config)
        logger = get_logger('mxnet_logger[tid=%s]' % task_id,
                            log_to_console=False,
                            log_to_file=True)
        mod = Module(symbol=symbol, context=context, logger=logger)
        mod._arg_params = params_by_kind['arg']
        mod._aux_params = params_by_kind['aux']
        mod.params_initialized = True
        # TODO: There is a parameter named load_optimizer_states in Module.load
        return mod