예제 #1
0
파일: function.py 프로젝트: zfxu/Dragon
def GraphDef_Update(meta_graph, updater):
    """Inject the update targets into GraphDef.

    The ``updater`` should generate update targets before.

    One ``UpdateTarget`` is appended to ``meta_graph.u_target`` for every
    entry of ``updater._tuples``. Each hyper-parameter of the updater is
    also fed into the workspace as a 1-element ``float32`` array named
    ``<meta_graph.name>_<hyper-parameter name>``.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.
    updater : BaseUpdater
        The updater. ``None`` makes this function a no-op.

    Returns
    -------
    None

    Notes
    -----
    ``updater._prefix`` is overwritten, and ``updater._extra_kwargs`` is
    updated in place with a ``'domain'`` entry, so both persist on the
    updater after this call.

    """
    if updater is None:
        return

    updater._prefix = meta_graph.name + '_'
    # Alias (not a copy): the 'domain' entry is written into the updater's
    # own dict.
    extra_arguments = updater._extra_kwargs
    extra_arguments['domain'] = updater._prefix
    parallel_arguments = {}

    # wrap hyper-parameters as Tensor for CC
    for k, v in updater._hyper_params.items():
        ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))

    # check data parallel if necessary
    if mpi.Is_Init():
        idx, group = mpi.AllowParallel()
        if idx != -1:
            parallel_arguments['parallel_mode'] = mpi.GetParallelMode()
            parallel_arguments['comm'], parallel_arguments['group'] \
                = mpi.CreateGroup(root=group[0], incl=group)
            parallel_arguments['root'] = group[0]
        for k, v in parallel_arguments.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(k, v))

    for target_def in updater._tuples:
        # Entry layout: [0] -> items appended to ``u_target.tensor``,
        # [1] -> per-target arguments (a mapping).
        tensors, arguments = target_def[0], target_def[1]
        # On key clashes the graph-level extras override per-target values.
        kwargs = dict(arguments, **extra_arguments)
        u_target = pb.UpdateTarget()
        u_target.type = updater._type
        _, u_target.name = GetOperatorName()
        for tensor in tensors:
            u_target.tensor.append(tensor)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        meta_graph.u_target.extend([u_target])
예제 #2
0
파일: function.py 프로젝트: k9sret/Dragon
def GraphDef_Update(meta_graph, updater):
    """Append one UpdateTarget per parameter group of ``updater`` to GraphDef.

    The ``updater`` is expected to have populated its parameter groups
    beforehand.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.
    updater : BaseUpdater
        The updater; passing ``None`` leaves ``meta_graph`` untouched.

    Returns
    -------
    None

    """
    if updater is None:
        return

    # Fall back to the graph name when the updater has no slot yet.
    if updater._slot is None:
        updater._slot = meta_graph.name
    shared_kwargs = updater._extra_kwargs  # same dict object, mutated in place
    shared_kwargs['slot'] = updater._slot
    mpi_kwargs = {}

    updater.register_in_workspace()

    # Attach data-parallel settings to the graph when MPI is initialized.
    if mpi.Is_Init():
        rank_idx, ranks = mpi.AllowParallel()
        if rank_idx != -1:
            mpi_kwargs['parallel_mode'] = mpi.GetParallelMode()
            comm, comm_group = mpi.CreateGroup(root=ranks[0], incl=ranks)
            mpi_kwargs['comm'] = comm
            mpi_kwargs['group'] = comm_group
            mpi_kwargs['root'] = ranks[0]
        for name, value in mpi_kwargs.items():
            meta_graph.arg.add().CopyFrom(MakeArgument(name, value))

    for tensor_pair, group_args in updater._param_group:
        # Updater-wide extras take precedence over the group's own args.
        merged = dict(group_args, **shared_kwargs)
        target = pb.UpdateTarget()
        target.type = updater.type()
        _, target.name = GetOperatorName()
        for tensor_name in tensor_pair:
            target.tensor.append(tensor_name)
        for name, value in merged.items():
            target.arg.add().CopyFrom(MakeArgument(name, value))
        meta_graph.u_target.extend([target])
예제 #3
0
파일: function.py 프로젝트: zfxu/Dragon
def GraphDef_Phase(meta_graph, targets):
    """Record the execution phase as a ``'phase'`` argument of GraphDef.

    A non-empty phase scope wins and is upper-cased. Without one, the phase
    is ``TRAIN`` when any target has a non-empty ``grad_wrts`` and ``TEST``
    otherwise.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.
    targets : list
        The solving targets.

    Returns
    -------
    None

    """
    from dragon.core.scope import _PHASE_SCOPE
    if _PHASE_SCOPE != '':
        phase = _PHASE_SCOPE.upper()
    elif any(len(t.grad_wrts) > 0 for t in targets):
        phase = 'TRAIN'
    else:
        phase = 'TEST'
    meta_graph.arg.extend([MakeArgument('phase', phase)])
예제 #4
0
def GraphDef_Update(graph_def, updater):
    """Generate all update targets for the CC graph.

    One ``UpdateTarget`` is appended to ``graph_def.u_target`` for every
    entry of ``updater._tuples``. Each hyper-parameter of the updater is
    also fed into the workspace as a 1-element ``float32`` array named
    ``<graph_def.name>_<hyper-parameter name>``.

    Parameters
    ----------
    graph_def : GraphDef
        The graph definition to extend (presumably ``dragon_pb2.GraphDef``).
    updater : BaseUpdater or None
        The updater. ``None`` makes this function a no-op.

    Returns
    -------
    None

    Notes
    -----
    ``updater._prefix`` is overwritten, and ``updater._extra_kwargs`` is
    updated in place. It always receives ``'domain'``, and under data
    parallelism also ``'comm'``, ``'group'``, ``'root'``, ``'mode'`` and
    ``'group_size'``. These entries persist on the updater after this call.

    """
    if updater is None:
        return

    updater._prefix = graph_def.name + '_'
    # Alias (not a copy): every entry below is written into the updater's
    # own dict.
    extra_kwargs = updater._extra_kwargs
    extra_kwargs['domain'] = updater._prefix

    # wrap hyper-parameters as Tensor for CC
    for k, v in updater._hyper_params.items():
        ws.FeedTensor(updater._prefix + k, np.array([v], dtype=np.float32))

    # check data parallel if necessary
    if mpi.is_init():
        idx, group = mpi.allow_parallel()
        if idx != -1:
            extra_kwargs['comm'], extra_kwargs['group'] \
                = mpi.group(root=group[0], incl=group)
            extra_kwargs['root'] = group[0]
            extra_kwargs['mode'] = mpi.get_parallel_mode()
            extra_kwargs['group_size'] = len(group)

    for target_def in updater._tuples:
        # Entry layout: [0] -> items appended to ``u_target.tensor``,
        # [1] -> per-target keyword arguments (a mapping).
        tensors, target_kwargs = target_def[0], target_def[1]
        # On key clashes the graph-level extras override per-target values.
        kwargs = dict(target_kwargs, **extra_kwargs)
        u_target = pb.UpdateTarget()
        u_target.type = updater._type
        _, u_target.name = GetOperatorName()
        for tensor in tensors:
            u_target.tensor.append(tensor)
        for k, v in kwargs.items():
            u_target.arg.add().CopyFrom(MakeArgument(k, v))
        graph_def.u_target.extend([u_target])
예제 #5
0
파일: function.py 프로젝트: k9sret/Dragon
def GraphDef_Opt(meta_graph):
    """Inject the optimization options into GraphDef.

    An ``'optimization_level'`` argument is appended to ``meta_graph.arg``.
    Its value is 1 when ``option['debug_mode']`` is set. Otherwise it is 3
    when ``option['share_grads']`` is set, and 2 when neither is. The graph
    type is copied from ``option['graph_type']``.

    Parameters
    ----------
    meta_graph : dragon_pb2.GraphDef
        The definition of meta graph.

    Returns
    -------
    None

    References
    ----------
    `config.SetDebugMode(*args, **kwargs)`_ - How to enable debug mode.

    `memonger.share_grads(*args, **kwargs)`_ - How to enable gradient sharing.

    """

    from dragon.config import option
    OX = 3 if option['share_grads'] else 2
    # Debug mode takes precedence over gradient sharing.
    if option['debug_mode']: OX = 1
    meta_graph.arg.add().CopyFrom(MakeArgument('optimization_level', OX))
    meta_graph.graph_type = option['graph_type']
예제 #6
0
def GraphDef_Phase(graph_def, targets):
    """Inject the phase into GraphDef.

    A non-empty ``dragon.core.scope.PHASE_SCOPE`` takes precedence and is
    upper-cased. Otherwise the phase is ``TRAIN`` if any target has a
    non-empty ``grad_wrts``, and ``TEST`` if none does.

    Parameters
    ----------
    graph_def : GraphDef
        The graph definition (presumably ``dragon_pb2.GraphDef``); a
        ``'phase'`` argument is appended to ``graph_def.arg``.
    targets : list
        The solving targets.

    Returns
    -------
    None

    """
    # Read the scope value at call time. Do not add ``global PHASE_SCOPE``
    # after this import: declaring a name global after binding it is a
    # SyntaxError on Python 3.6+, and a local binding suffices for reading.
    from dragon.core.scope import PHASE_SCOPE
    phase = 'TEST'
    if PHASE_SCOPE != '':
        phase = PHASE_SCOPE.upper()
    else:
        for target in targets:
            if len(target.grad_wrts) > 0:
                phase = 'TRAIN'
                break
    graph_def.arg.extend([MakeArgument('phase', phase)])