Example #1
def ExportMetaGraph(graph_def):
    """Export the meta graph into a file under specific folder.

    You can set the export prefix via `config.ExportMetaGraph(prefix)`_.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of meta graph.

    Returns
    -------
    None

    """
    option = GetGlobalOptions()
    if option['export_meta_graph']:
        if not os.path.exists(option['export_meta_graph']):
            try:
                os.makedirs(option['export_meta_graph'])
            except Exception:
                raise ValueError('The given prefix is invalid.')

        path = os.path.join(option['export_meta_graph'],
                            graph_def.name + '.metatxt')

        with open(path, 'w') as f:
            f.write(str(graph_def))
        logging.info('Export meta graph into: {}'.format(path))
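A minimal usage sketch (assuming `config` and `ExportMetaGraph` are importable from this package; the prefix directory is illustrative):

# Hypothetical usage: set the export prefix, then dump a GraphDef.
# The text file lands at ./exports/<graph_def.name>.metatxt
config.ExportMetaGraph(prefix='./exports')
ExportMetaGraph(graph_def)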
Example #2
def CreateGraph(graph_def):
    """Create the graph in current workspace.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of meta graph.

    Returns
    -------
    str
        The graph name to run.

    """
    options = _cfg.GetGlobalOptions()
    if options['log_meta_graph']: print(graph_def)
    if options['export_meta_graph']:
        if not os.path.exists(options['export_meta_graph']):
            try:
                os.makedirs(options['export_meta_graph'])
            except Exception:
                raise ValueError('The given prefix is invalid.')
        path = os.path.join(options['export_meta_graph'],
                            graph_def.name + '.metatxt')
        with open(path, 'w') as f:
            f.write(str(graph_def))
        _logging.info('Export meta graph to: {}'.format(path))
    return get_default_workspace().CreateGraph(_stringify_proto(graph_def),
                                               options['log_optimized_graph'])
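A sketch of a hypothetical call site; only `CreateGraph` itself is shown above, so executing the graph by name is left as a commented assumption:

# Hypothetical usage: create the graph and keep the returned name around.
graph_name = CreateGraph(graph_def)

# Running the graph by name is an assumption about the surrounding
# workspace API and is not shown in this snippet:
# get_default_workspace().RunGraph(graph_name)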
Example #3
def _inject_update_ops(graph_def, updater):
    """Inject the update ops GraphDef.

    The ``updater`` should have generated its update targets beforehand.

    Parameters
    ----------
    graph_def : GraphDef
        The definition of graph.
    updater : BaseUpdater
        The updater.

    Returns
    -------
    None

    """
    if updater is None: return
    updater.register_in_workspace()

    grads, update_ops = [], []
    extra_arguments = updater._extra_kwargs
    extra_arguments['slot'] = updater._slot

    # Build update ops according to the updater
    for e in updater._param_group:
        (param, grad), arguments = e
        if _workspace.HasTensor(grad):
            grads.append(grad)
            arguments = dict(arguments, **extra_arguments)
            update_ops.append(
                _proto_utils.MakeOperatorDef(
                    op_type=updater.type(),
                    inputs=[grad],
                    outputs=[param],
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))
        else:
            _logging.info('Skip updating Tensor({}).'.format(param))

    # Check data parallel if necessary
    if _mpi.Is_Init():
        (rank, group), arguments = _mpi.AllowParallel(), {}
        if rank != -1:
            arguments['mode'] = '%s_ALLREDUCE' % _mpi.GetParallelMode()
            arguments['root'], (arguments['comm'], arguments['group']) \
                = group[0], _mpi.CreateGroup(root=group[0], incl=group)
            update_ops.insert(
                0,
                _proto_utils.MakeOperatorDef(
                    op_type='CollectiveUpdate',
                    inputs=grads,
                    outputs=grads,
                    name=_helper.OperatorHelper.get_name(),
                    **arguments))

    graph_def.op.extend(update_ops)
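For orientation, a hedged sketch of how the injection step might be invoked; `make_updater()` is a hypothetical factory for any object implementing the BaseUpdater interface referenced in the docstring:

# Hypothetical call site: once the training targets are in `graph_def`
# and the updater holds its (param, grad) pairs, append the update ops.
updater = make_updater()            # any BaseUpdater-like object (assumed)
_inject_update_ops(graph_def, updater)
# graph_def.op now ends with one update op per parameter, optionally
# preceded by a CollectiveUpdate op when MPI data parallelism is active.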
Example #4
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='default',
):
    """Snapshot tensors into a binary file.

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str
        The prefix of this binary file.
    suffix : str
        The suffix of this binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    The full file path will be:  ``prefix`` + ``filename`` + ``suffix``.

    Available formats: ['default', 'caffe'].

    """
    file_path = prefix + filename + suffix
    if mpi.Is_Init():
        if not mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(mpi.Rank())

    dir = os.path.split(file_path)[0]
    if len(dir) > 0 and not os.path.exists(dir): os.makedirs(dir)

    if format == 'default':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        logging.info('Snapshot Model@: ' + file_path)
        logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        _C.Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
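A usage sketch, assuming `net_params` is a list of workspace tensors (names and paths are illustrative):

# Hypothetical usage with the pickle-based default format.
# Resulting file: './snapshots/model_iter_10000.bin'
Snapshot(
    net_params,
    filename='model_iter_10000',
    prefix='./snapshots/',
    suffix='.bin',
    format='default',
)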
Example #5
def Snapshot(
    tensors,
    filename,
    prefix='',
    suffix='.bin',
    format='pickle',
):
    """Serialize tensors into a binary file.

    The filename is formatted as:

        ``prefix`` + ``filename`` + ``suffix``

    Parameters
    ----------
    tensors : list of Tensor or Tensor
        The tensors to be written.
    filename : str
        The name of this binary file.
    prefix : str, optional, default=''
        The prefix of this binary file.
    suffix : str, optional, default='.bin'
        The suffix of this binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    file_path = prefix + filename + suffix

    if _mpi.Is_Init():
        if not _mpi.AllowSnapshot(): return
        file_path = file_path + '.rank.{}'.format(_mpi.Rank())

    dir = os.path.split(file_path)[0]
    if len(dir) > 0 and not os.path.exists(dir): os.makedirs(dir)

    if format == 'pickle':
        state_dict = {}
        for tensor in tensors:
            state_dict[tensor.name] = FetchTensor(tensor)
        with open(file_path, 'wb') as f:
            pickle.dump(state_dict, f, pickle.HIGHEST_PROTOCOL)
        _logging.info('Snapshot Model@: ' + file_path)
        _logging.info('Model Format: Pickle')
    elif format == 'caffe':
        names = [tensor.name for tensor in tensors]
        get_default_workspace().Snapshot(file_path, names, 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
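The 'caffe' branch hands only the tensor names to the workspace; a hedged sketch of that variant (tensor list and paths are illustrative):

# Hypothetical usage of the caffe branch.
# Resulting file: './snapshots/model_final.caffemodel'
Snapshot(
    net_params,
    filename='model_final',
    prefix='./snapshots/',
    suffix='.caffemodel',
    format='caffe',
)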
Example #6
def Restore(binary_file, format='pickle'):
    """Restore tensors from a binary file.

    Parameters
    ----------
    binary_file : str
        The path of binary file.
    format : {'pickle', 'caffe'}, optional
        The format of this binary file.

    Returns
    -------
    None

    """
    assert os.path.exists(binary_file), \
        'Binary file({}) does not exist.'.format(binary_file)
    if format == 'pickle':
        try:
            state_dict = pickle.load(open(binary_file, 'rb'))
        except UnicodeDecodeError:
            state_dict = pickle.load(open(binary_file, 'rb'),
                                     encoding='iso-8859-1')
        _logging.info('Restore From Model@: ' + binary_file)
        _logging.info('Model Format: Pickle')
        for k, v in state_dict.items():
            if HasTensor(k):
                FeedTensor(k, v)
                _logging.info('Tensor({}) is restored.'.format(k))
    elif format == 'caffe':
        get_default_workspace().Restore(binary_file, 1)
    else:
        raise TypeError('Unknown binary format: ' + format)
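A sketch of the matching restore step, assuming the file was produced by the pickle-based Snapshot above:

# Hypothetical usage: feed every tensor found in the pickle back into the
# current workspace; keys without a matching workspace tensor are skipped.
Restore('./snapshots/model_iter_10000.bin', format='pickle')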
Example #7
 def load_state_dict(self, state_dict, strict=True, verbose=True):
     if verbose: _logging.info('Load the state dict.')
     unexpected = []
     own_state = self.state_dict()
     for name, param in state_dict.items():
         if name in own_state:
             state_shape = own_state[name].shape
             param_shape = param.shape
             if state_shape != param_shape:
                 raise ValueError(
                     'Size of state({}) is ({}),\n'
                     'while loading from size of ({}).'.format(
                         name, ', '.join([str(d) for d in state_shape]),
                         ', '.join([str(d) for d in param_shape])))
             if isinstance(param, Tensor):
                 own_state[name].copy_(param)
             elif isinstance(param, numpy.ndarray):
                 _tensor_utils.SetArray(own_state[name], param)
             else:
                 raise ValueError(
                     'Expected the type of source state to be either '
                     'dragon.vm.torch.Tensor or numpy.ndarray, got {}.'.
                     format(type(param)))
             if verbose:
                 _logging.info('Tensor({}) loaded, Size: ({})'.format(
                     name, ', '.join([str(d) for d in param_shape])))
         else:
             unexpected.append(name)
     if strict:
         missing = set(own_state.keys()) - set(state_dict.keys())
         error_msg = ''
         if len(unexpected) > 0:
             error_msg += 'Unexpected key(s) in state_dict: {}.\n'.format(
                 ', '.join('"{}"'.format(k) for k in unexpected))
         if len(missing) > 0:
             error_msg += 'Missing key(s) in state_dict: {}.'.format(
                 ', '.join('"{}"'.format(k) for k in missing))
         if len(error_msg) > 0:
             raise KeyError(error_msg)
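A hedged usage sketch, assuming `model_a` and `model_b` are modules exposing this method; values in the state dict may be tensors or numpy arrays, as the type check above allows:

# Hypothetical usage: copy parameters from one module into another.
state = model_a.state_dict()          # or a dict loaded from a checkpoint
model_b.load_state_dict(state, strict=True, verbose=False)

# With strict=False, no KeyError is raised for unexpected or missing keys.
model_b.load_state_dict(state, strict=False)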
Example #8
    def export_to(self, name=None, export_dir='./'):
        """Export the meta graph of this defined function.

        Parameters
        ----------
        name : str, optional
            The name of the exported meta graph. If not given, the name
            of the meta graph itself is used.
        export_dir : str
            The directory to export the meta text file.

        Returns
        -------
        None

        """
        if not os.path.exists(export_dir):
            try:
                os.makedirs(export_dir)
            except Exception:
                raise ValueError('The given directory cannot be created.')
        meta_graph_copy = copy.deepcopy(self.meta_graph)
        meta_graph_copy.name = self.meta_graph.name if name is None else name
        file = os.path.join(export_dir, meta_graph_copy.name + '.metatxt')
        with open(file, 'w') as f:
            f.write(str(meta_graph_copy))
        _logging.info('Export meta graph into: {}'.format(file))
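A sketch assuming `fn` is a compiled function object exposing this method (the names are illustrative):

# Hypothetical usage: write the meta graph text of a compiled function.
# Produces './exported/MyGraph.metatxt'
fn.export_to(name='MyGraph', export_dir='./exported')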
Example #9
 def cleanup():
     def terminate(processes):
         for process in processes:
             process.terminate()
             process.join()
     terminate(self._fetchers)
     if local_rank == 0: _logging.info('Terminate BlobFetcher.')
     terminate(self._transformers)
     if local_rank == 0: _logging.info('Terminate DataTransformer.')
     terminate(self._readers)
     if local_rank == 0: _logging.info('Terminate DataReader.')
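Such cleanup closures are commonly registered as exit hooks; a minimal sketch of that pattern (the atexit wiring is an assumption and not part of the snippet above):

import atexit

# Hypothetical wiring: run the closure at interpreter exit so the fetcher,
# transformer and reader processes are terminated and joined in order.
atexit.register(cleanup)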
Example #10
def Restore(binary_file, format='default'):
    """Restore tensors from a binary file.

    Parameters
    ----------
    binary_file : str
        The path of binary file.
    format : str
        The format of this binary file.

    Returns
    -------
    None

    Notes
    -----
    Available formats: ['default', 'caffe'].

    """
    assert os.path.exists(binary_file), \
        'Binary file({}) does not exist.'.format(binary_file)

    if format == 'default':
        try:
            state_dict = pickle.load(open(binary_file, 'rb'))
        except UnicodeDecodeError:
            state_dict = pickle.load(open(binary_file, 'rb'),
                                     encoding='iso-8859-1')
        logging.info('Restore From Model@: ' + binary_file)
        logging.info('Model Format: Pickle')
        for k, v in state_dict.items():
            if HasTensor(k):
                FeedTensor(k, v)
                logging.info('Tensor({}) is restored.'.format(k))
    elif format == 'caffe':
        # Caffe models can't save the tensor name
        # We simply use "layer_name/param:X"
        _C.Restore(binary_file, 1)
    else:
        raise TypeError('Unknown binary format: {}'.format(format))
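Because caffe files do not carry tensor names, the restored workspace keys follow the 'layer_name/param:X' convention noted in the comment above; a hedged sketch of inspecting one such tensor afterwards (the layer name is made up):

# Hypothetical usage: restore a caffemodel, then fetch a parameter by its
# synthesized 'layer_name/param:X' key.
Restore('./models/net.caffemodel', format='caffe')
if HasTensor('conv1/param:0'):
    weights = FetchTensor('conv1/param:0')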
Example #11
 def cleanup():
     logging.info('Terminating DragonBoard......')
     self.terminate()
     self.join()