Exemplo n.º 1
0
def broadcast_global_variables(root_rank):
    """Build one grouped op that broadcasts Keras-initialized variables.

    Each variable of ``backend.get_graph()`` whose ``_keras_initialized``
    attribute is truthy is passed to ``hccl_ops.broadcast``. When the
    broadcast yields a result, both the broadcast op and an assign of its
    first output back into the variable are collected.

    Args:
        root_rank: forwarded to ``hccl_ops.broadcast``; presumably the rank
            whose values are sent to the others.

    Returns:
        ``control_flow_ops.group`` over the collected broadcast/assign ops.
    """
    graph_vars = backend._get_variables(backend.get_graph())
    initialized = [
        var for var in graph_vars if getattr(var, "_keras_initialized", False)
    ]

    grouped_ops = []
    for var in initialized:
        result = hccl_ops.broadcast(tensor=[var], root_rank=root_rank)
        if result is None:
            # Nothing to assign for this variable.
            continue
        broadcast_value = result[0]
        grouped_ops.extend(
            [broadcast_value.op, state_ops.assign(var, broadcast_value)])
    return control_flow_ops.group(grouped_ops)
Exemplo n.º 2
0
def _broadcast_variables(session, root_rank=0):
  """Broadcast Keras-initialized variables of the current graph, once per graph.

  Every variable of ``backend.get_graph()`` whose ``_keras_initialized``
  attribute is truthy is sent through ``hccl_ops.broadcast``. The first
  broadcast output is then assigned back into the variable, and all
  resulting ops are run together in ``session``. The graph's ``_graph_key``
  is then recorded in the module-level ``_keras_graph_key`` list, so later
  calls for the same graph do nothing.

  Args:
    session: session that runs the grouped broadcast/assign ops.
    root_rank: forwarded to ``hccl_ops.broadcast``; presumably the rank whose
      values are copied to the others. Defaults to 0.
  """
  graph = backend.get_graph()
  graph_key = graph._graph_key  # pylint: disable=protected-access
  if graph_key in _keras_graph_key:
    # Already handled for this graph. Returning here, instead of appending
    # the key again, keeps the list from growing by one entry per call.
    return

  candidate_vars = [
      v for v in backend._get_variables(graph)  # pylint: disable=protected-access
      if getattr(v, "_keras_initialized", False)
  ]
  if candidate_vars:
    op_list = []
    for var in candidate_vars:
      outputs = hccl_ops.broadcast(tensor=[var], root_rank=root_rank)
      if outputs is not None:
        op_list.append(outputs[0].op)
        op_list.append(state_ops.assign(var, outputs[0]))
    session.run(control_flow_ops.group(op_list))
  # NOTE(review): the key is recorded even when no variable qualified. A graph
  # first seen before any Keras variable was initialized is therefore never
  # broadcast later. Confirm this is intended.
  _keras_graph_key.append(graph_key)
Exemplo n.º 3
0
def _wait_for_variable_initialization(session):
    """Block until the not-yet-flagged Keras variables are initialized.

    Variables of ``K.get_graph()`` already flagged with a truthy
    ``_keras_initialized`` attribute are skipped. The rest are polled through
    ``session.run`` until every one reports initialized. They are then
    flagged ``_keras_initialized = True`` so later calls skip them.

    Args:
        session: session used to evaluate the initialization checks.
    """
    all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
    candidate_vars = [
        v for v in all_variables if not getattr(v, '_keras_initialized', False)
    ]
    if not candidate_vars:
        return

    # Build the check ops once and re-run them. Creating them inside the loop
    # would add a fresh set of ops to the graph on every polling iteration.
    is_initialized_ops = [
        variables.is_variable_initialized(v) for v in candidate_vars]

    # NOTE(review): this polls without any delay (busy-wait). Consider a short
    # sleep if another worker is expected to initialize these variables.
    while True:
        flags = session.run(is_initialized_ops)
        if all(flags):
            break

    # Flag only after every variable has been confirmed initialized.
    for v in candidate_vars:
        v._keras_initialized = True  # pylint: disable=protected-access
def _wait_for_variable_initialization(session):
  """Block until the not-yet-flagged Keras variables are initialized.

  Variables of ``K.get_graph()`` already flagged with a truthy
  ``_keras_initialized`` attribute are skipped. The rest are polled through
  ``session.run`` until every one reports initialized. They are then flagged
  ``_keras_initialized = True`` so later calls skip them.

  Args:
    session: session used to evaluate the initialization checks.
  """
  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  candidate_vars = [
      v for v in all_variables if not getattr(v, '_keras_initialized', False)
  ]
  if not candidate_vars:
    return

  # Build the check ops once and re-run them. Creating them inside the loop
  # would add a fresh set of ops to the graph on every polling iteration.
  is_initialized_ops = [
      variables.is_variable_initialized(v) for v in candidate_vars]

  # NOTE(review): this polls without any delay (busy-wait). Consider a short
  # sleep if another worker is expected to initialize these variables.
  while True:
    flags = session.run(is_initialized_ops)
    if all(flags):
      break

  # Flag only after every variable has been confirmed initialized.
  for v in candidate_vars:
    v._keras_initialized = True  # pylint: disable=protected-access