Exemplo n.º 1
0
    def backend_mapping(self, config):
        """Translate a generic config into its backend-specific form.

        The ``type`` entry is looked up in ``self.type_mapping_dict`` and every
        entry of ``params`` is renamed through ``self.params_mapping_dict``;
        both tables are indexed by ``self.backend_type``. Parameters that have
        no entry in the renaming table, or whose entry is ``None``, are
        dropped. A parameter value that is a dict holding a ``type`` key is
        mapped recursively.

        :param config: original config from config file
        :type config: Config or dict
        :return: the mapped config; when the config type has no entry in
            ``self.type_mapping_dict`` the ``config`` argument itself is
            returned unchanged
        :rtype: Config (or the type of ``config`` for unmapped types)
        """
        # Read from a deep copy wrapped in Config rather than from the input.
        source = Config(copy.deepcopy(config))
        config_type = source.type
        if config_type not in self.type_mapping_dict:
            return config

        source_params = source.get('params', {})
        mapped = Config()
        mapped.type = self.type_mapping_dict[config_type][self.backend_type]
        mapped.params = Config()

        rename_table = self.params_mapping_dict.get(config_type, {})
        for name, value in source_params.items():
            if name not in rename_table:
                continue
            target = rename_table[name][self.backend_type]
            if target is None:
                continue
            # Nested typed configs go through the same mapping.
            if isinstance(value, dict) and 'type' in value:
                value = self.backend_mapping(value)
            mapped.params[target] = value

        return Config(mapped)
Exemplo n.º 2
0
def run_pipeline(load_special_lib_func=None) -> None:
    """Run pipeline.

    Parses the command-line arguments, loads the config file they point to,
    applies the ``general`` section to ``General``, and hands the resulting
    config to ``vega.run``. The steps below are order-dependent.

    :param load_special_lib_func: optional callable; when truthy it is called
        with ``args.config_file`` before the config file is loaded
    :return: None. Returns early, without calling ``vega.run``, when
        ``General.requires`` is set and ``verify_requires`` rejects it.
    """
    args = _parse_args()
    # Environment preparation driven by the parsed arguments.
    # NOTE(review): the exact effects of these helpers are not visible here.
    _resume(args)
    _set_backend(args)
    _append_env()
    if load_special_lib_func:
        load_special_lib_func(args.config_file)
    config = Config(args.config_file)
    # load general
    # Only a truthy "general" section is applied to General.
    if config.get("general"):
        General.from_dict(config.get("general"), skip_check=False)
    # Export the configured TensorFlow C++ log level as a string env variable.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(General.TF_CPP_MIN_LOG_LEVEL)
    # Stop here when requirements are declared but do not verify.
    if General.requires and not verify_requires(General.requires):
        return
    # Turn the argparse namespace into a dict and merge it with the config.
    # NOTE(review): presumably command-line values override the file; confirm
    # against _check_parse / _modify_config.
    dict_args = vars(args)
    dict_args = _check_parse(dict_args)
    config = _modify_config(dict_args, config)
    _backup_config(args)
    _change_process_name()
    vega.run(config)
Exemplo n.º 3
0
 def __init__(self, **desc):
     """Create the layers of the network from keyword settings.

     Reads ``num_class``, ``channels`` and ``blocks`` from the settings, plus
     the optional ``fp16`` flag, which defaults to ``False`` when absent.
     """
     super(SimpleCnn, self).__init__()
     settings = Config(**desc)
     head_width = 64  # channel count shared by conv2, global_conv and fc

     self.num_class = settings.num_class
     self.fp16 = settings.get('fp16', False)
     self.channels = settings.channels

     # NOTE(review): keep the layer assignments in this order; the base class
     # may depend on registration order. Confirm before reordering.
     self.conv1 = ops.Conv2d(3, 32, kernel_size=3, padding=1)
     self.pool1 = ops.MaxPool2d(2, stride=2)
     self.blocks = self._blocks(self.channels, settings.blocks)
     self.pool2 = ops.MaxPool2d(2, stride=2)
     self.conv2 = ops.Conv2d(
         self.channels, head_width, kernel_size=3, padding=1)
     self.global_conv = ops.Conv2d(
         head_width, head_width, padding=0, kernel_size=8)
     self.view = ops.View()
     self.fc = ops.Linear(head_width, self.num_class)