def ExportMetaGraph(graph_def):
    """Export the meta graph into a file under specific folder.

    You can set the exporting prefix by `config.ExportMetaGraph(prefix)`_.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of meta graph.

    Returns
    -------
    None

    """
    option = GetGlobalOptions()
    if option['export_meta_graph']:
        if not os.path.exists(option['export_meta_graph']):
            try:
                os.makedirs(option['export_meta_graph'])
            except OSError:
                # Narrowed from ``Exception``: only filesystem failures
                # from makedirs should be reported as a bad prefix;
                # anything else is a programming error and must surface.
                raise ValueError('The given prefix is invalid.')
        path = os.path.join(
            option['export_meta_graph'],
            graph_def.name + '.metatxt')
        # Write the text-format proto of the graph.
        with open(path, 'w') as f:
            f.write(str(graph_def))
        logging.info('Export meta graph into: {}'.format(path))
def CreateGraph(graph_def):
    """Create the graph in current workspace.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of meta graph.

    Returns
    -------
    str
        The graph name to run.

    """
    options = _cfg.GetGlobalOptions()
    if options['log_meta_graph']:
        print(graph_def)
    if options['export_meta_graph']:
        if not os.path.exists(options['export_meta_graph']):
            try:
                os.makedirs(options['export_meta_graph'])
            except OSError:
                # Narrowed from ``Exception``: only filesystem failures
                # from makedirs indicate a bad prefix; other exceptions
                # should propagate unchanged.
                raise ValueError('The given prefix is invalid.')
        path = os.path.join(
            options['export_meta_graph'],
            graph_def.name + '.metatxt')
        with open(path, 'w') as f:
            f.write(str(graph_def))
        _logging.info('Export meta graph to: {}'.format(path))
    # Register the (serialized) graph and return its runnable name.
    return get_default_workspace().CreateGraph(
        _stringify_proto(graph_def), options['log_optimized_graph'])
def _inject_update_ops(graph_def, updater):
    """Inject the parameter-update ops into ``graph_def``.

    The ``updater`` should generate update targets before.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of graph; its ``op`` list is extended in place.
    updater : BaseUpdater
        The updater supplying (param, grad) pairs and hyper-parameters.

    Returns
    -------
    None

    """
    # Nothing to inject when no updater is attached.
    if updater is None: return
    updater.register_in_workspace()
    grads, update_ops = [], []
    # Arguments shared by every update op (e.g. the history slot).
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot
    # Build update ops according to the updater
    for e in updater._param_group:
        (param, grad), arguments = e
        if _workspace.HasTensor(grad):
            grads.append(grad)
            # Merge shared extras in; keys in ``extra_arguments`` win
            # over per-parameter arguments.
            arguments = dict(arguments, **extra_arguments)
            update_ops.append(
                _proto_utils.MakeOperatorDef(
                    op_type=updater.type(),
                    inputs=[grad],
                    outputs=[param],
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))
        else:
            # No gradient was produced for this parameter; skip it.
            _logging.info('Skip to update Tensor({}).'.format(param))
    # Check data parallel if necessary
    if _mpi.Is_Init():
        (rank, group), arguments = _mpi.AllowParallel(), {}
        # NOTE(review): rank == -1 appears to mean "not in a parallel
        # group", in which case no collective op is inserted — confirm
        # against _mpi.AllowParallel().
        if rank != -1:
            arguments['mode'] = '%s_ALLREDUCE' % _mpi.GetParallelMode()
            arguments['root'], (arguments['comm'], arguments['group']) \
                = group[0], _mpi.CreateGroup(root=group[0], incl=group)
            # Prepend so gradients are all-reduced across the group
            # before any local update op consumes them.
            update_ops.insert(
                0, _proto_utils.MakeOperatorDef(
                    op_type='CollectiveUpdate',
                    inputs=grads,
                    outputs=grads,
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))
    graph_def.op.extend(update_ops)
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='default',
):
    """Snapshot tensors into a binary file.

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be wrote.
    filename : str
        The name of this binary file.
    prefix : str
        The prefix of this binary file.
    suffix : str
        The suffix of this binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    The full file path will be:  ``prefix`` + ``filename`` + ``suffix``.

    Available formats: ['default', 'caffe'].

    """
    file_path = prefix + filename + suffix
    if mpi.Is_Init():
        # Only ranks allowed to snapshot write anything at all.
        if not mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(mpi.Rank())
    # Renamed from ``dir`` to stop shadowing the builtin.
    dir_path = os.path.split(file_path)[0]
    if len(dir_path) > 0 and not os.path.exists(dir_path):
        os.makedirs(dir_path)
    if format == 'default':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        logging.info('Snapshot Model@: ' + file_path)
        logging.info('Model Format: Pickle')
    elif format == 'caffe':
        # BUGFIX: was ``format is 'caffe'`` — identity comparison with a
        # string literal is implementation-dependent and may be False for
        # an equal string; equality is the correct test.
        names = [tensor.name for tensor in tensors]
        _C.Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='pickle',
):
    """Serialize tensors into a binary file.

    The filename is formatted as:

        ``prefix`` + ``filename`` + ``suffix``

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be wrote.
    filename : str
        The name of this binary file.
    prefix : str, optional, default=''
        The prefix of this binary file.
    suffix : str, optional, default='.bin'
        The suffix of this binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    file_path = prefix + filename + suffix
    if _mpi.Is_Init():
        # Ranks that are not allowed to snapshot do nothing.
        if not _mpi.AllowSnapshot():
            return
        file_path += '.rank.{}'.format(_mpi.Rank())
    dir = os.path.split(file_path)[0]
    if len(dir) > 0 and not os.path.exists(dir):
        os.makedirs(dir)
    if format == 'pickle':
        # Collect the values first, then dump them in one shot.
        state_dict = {tensor.name: FetchTensor(tensor) for tensor in tensors}
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        _logging.info('Snapshot Model@: ' + file_path)
        _logging.info('Model Format: Pickle')
    elif format == 'caffe':
        get_default_workspace().Snapshot(
            file_path, [tensor.name for tensor in tensors], 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
def Restore(binary_file, format='pickle'):
    """Restore tensors from a binary file.

    Parameters
    ----------
    binary_file : str
        The path of binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    assert os.path.exists(binary_file), \
        'Binary file({}) does not exist.'.format(binary_file)
    if format == 'pickle':
        # BUGFIX: the file objects were opened inline and never closed,
        # leaking one handle per call; ``with`` guarantees closure.
        try:
            with open(binary_file, 'rb') as f:
                state_dict = pickle.load(f)
        except UnicodeDecodeError:
            # Payloads pickled under Python 2 need a byte-preserving decode.
            with open(binary_file, 'rb') as f:
                state_dict = pickle.load(f, encoding='iso-8859-1')
        _logging.info('Restore From Model@: ' + binary_file)
        _logging.info('Model Format: Pickle')
        for k, v in state_dict.items():
            # Only feed tensors that already exist in the workspace.
            if HasTensor(k):
                FeedTensor(k, v)
                _logging.info('Tensor({}) is restored.'.format(k))
    elif format == 'caffe':
        get_default_workspace().Restore(binary_file, 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
def load_state_dict(self, state_dict, strict=True, verbose=True):
    """Load parameters from ``state_dict`` into this module's own state.

    Parameters
    ----------
    state_dict : dict
        Mapping from parameter name to a Tensor or numpy.ndarray.
    strict : bool, optional, default=True
        Raise ``KeyError`` on unexpected or missing keys if ``True``.
    verbose : bool, optional, default=True
        Log each loaded tensor if ``True``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a shape mismatches or a source value has an unsupported type.
    KeyError
        If ``strict`` and the key sets differ.

    """
    if verbose:
        _logging.info('Load the state dict.')
    unexpected = []
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            state_shape = own_state[name].shape
            param_shape = param.shape
            # Shapes must agree exactly before copying.
            if state_shape != param_shape:
                raise ValueError(
                    'Size of state({}) is ({}), \n'
                    'While load from Size of ({}).'.format(
                        name,
                        ', '.join([str(d) for d in state_shape]),
                        ', '.join([str(d) for d in param_shape])))
            if isinstance(param, Tensor):
                own_state[name].copy_(param)
            elif isinstance(param, numpy.ndarray):
                _tensor_utils.SetArray(own_state[name], param)
            else:
                # BUGFIX: the message read "Excepted"; "Expected" is meant.
                raise ValueError(
                    'Expected the type of source state is either '
                    'dragon.vm.torch.Tensor or numpy.ndarray, got {}.'.
                        format(type(param)))
            if verbose:
                _logging.info('Tensor({}) loaded, Size: ({})'.format(
                    name, ', '.join([str(d) for d in param_shape])))
        else:
            unexpected.append(name)
    if strict:
        missing = set(own_state.keys()) - set(state_dict.keys())
        error_msg = ''
        if len(unexpected) > 0:
            error_msg += 'Unexpected key(s) in state_dict: {}.\n'.format(
                ', '.join('"{}"'.format(k) for k in unexpected))
        if len(missing) > 0:
            error_msg += 'Missing key(s) in state_dict: {}.'.format(
                ', '.join('"{}"'.format(k) for k in missing))
        if len(error_msg) > 0:
            raise KeyError(error_msg)
def export_to(self, name=None, export_dir='./'):
    """Export the meta graph of this defined function.

    Parameters
    ----------
    name : str, optional
        The optional name for the exported graph; the meta graph's
        own name is used when not given.
    export_dir : str
        The directory to export the meta text file.

    Returns
    -------
    None

    """
    if not os.path.exists(export_dir):
        try:
            os.makedirs(export_dir)
        except OSError:
            # Narrowed from ``Exception``: only filesystem failures
            # indicate an uncreatable directory.
            raise ValueError('The given directory can not be created.')
    # Deep-copy so renaming never mutates the live meta graph.
    meta_graph_copy = copy.deepcopy(self.meta_graph)
    meta_graph_copy.name = self.meta_graph.name if name is None else name
    file = os.path.join(export_dir, meta_graph_copy.name + '.metatxt')
    with open(file, 'w') as f:
        f.write(str(meta_graph_copy))
    _logging.info('Export meta graph into: {}'.format(file))
def cleanup():
    """Shut down every worker process this pipeline spawned."""
    def _shutdown(workers):
        # Terminate, then join, each worker in order.
        for worker in workers:
            worker.terminate()
            worker.join()
    # Stop the stages from sink to source, logging once per stage
    # on the lead rank only.
    for workers, message in (
            (self._fetchers, 'Terminate BlobFetcher.'),
            (self._transformers, 'Terminate DataTransformer.'),
            (self._readers, 'Terminate DataReader.')):
        _shutdown(workers)
        if local_rank == 0:
            _logging.info(message)
def Restore(binary_file, format='default'):
    """Restore tensors from a binary file.

    Parameters
    ----------
    binary_file : str
        The path of binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    Available formats: ['default', 'caffe'].

    """
    assert os.path.exists(binary_file), \
        'Binary file({}) does not exist.'.format(binary_file)
    if format == 'default':
        # BUGFIX: the file objects were opened inline and never closed,
        # leaking one handle per call; ``with`` guarantees closure.
        try:
            with open(binary_file, 'rb') as f:
                state_dict = pickle.load(f)
        except UnicodeDecodeError:
            # Payloads pickled under Python 2 need a byte-preserving decode.
            with open(binary_file, 'rb') as f:
                state_dict = pickle.load(f, encoding='iso-8859-1')
        logging.info('Restore From Model@: ' + binary_file)
        logging.info('Model Format: Pickle')
        for k, v in state_dict.items():
            # Only feed tensors that already exist in the workspace.
            if HasTensor(k):
                FeedTensor(k, v)
                logging.info('[Info]: Tensor({}) is restored.'.format(k))
    elif format == 'caffe':
        # Caffe models can't save the tensor name
        # We simply use "layer_name/param:X"
        _C.Restore(binary_file, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
def cleanup():
    """Stop the DragonBoard service and wait for its thread to exit."""
    logging.info('Terminating DragonBoard......')
    self.terminate()
    self.join()