def broadcast_global_variables(root_rank):
    """Build a grouped op that broadcasts the Keras-initialized variables.

    Collects the variables of the current backend graph whose
    ``_keras_initialized`` attribute is truthy, passes each of them to
    ``hccl_ops.broadcast`` and assigns the first broadcast output back to the
    variable it came from.

    Args:
        root_rank: Forwarded unchanged as the ``root_rank`` argument of
            ``hccl_ops.broadcast``.

    Returns:
        The result of ``control_flow_ops.group`` over the collected broadcast
        and assign ops. The group is built from an empty list when no variable
        qualifies.
    """
    graph_variables = backend._get_variables(backend.get_graph())
    initialized = [
        variable for variable in graph_variables
        if getattr(variable, "_keras_initialized", False)
    ]

    grouped_ops = []
    for variable in initialized:
        received = hccl_ops.broadcast(tensor=[variable], root_rank=root_rank)
        if received is None:
            # No broadcast output for this variable, so nothing to assign.
            continue
        first_output = received[0]
        grouped_ops.append(first_output.op)
        grouped_ops.append(state_ops.assign(variable, first_output))

    return control_flow_ops.group(grouped_ops)
def _broadcast_variables(session):
    """Broadcast the Keras-initialized variables, at most once per graph key.

    Looks up the current backend graph's ``_graph_key`` and does nothing when
    that key is already recorded in the module-level ``_keras_graph_key`` list
    or when no variable has a truthy ``_keras_initialized`` attribute.
    Otherwise every such variable is passed to ``hccl_ops.broadcast`` with
    ``root_rank=0``, the first output is assigned back to the variable, the
    grouped ops are executed and the graph key is recorded.

    Args:
        session: Object whose ``run`` method executes the grouped op.
    """
    graph_variables = backend._get_variables(backend.get_graph())
    current_key = backend.get_graph()._graph_key
    initialized = [
        variable for variable in graph_variables
        if getattr(variable, "_keras_initialized", False)
    ]

    # Skip graphs that were already handled, and graphs with nothing to send.
    if current_key in _keras_graph_key or not initialized:
        return

    pending_ops = []
    for variable in initialized:
        received = hccl_ops.broadcast(tensor=[variable], root_rank=0)
        if received is None:
            continue
        first_output = received[0]
        pending_ops.append(first_output.op)
        pending_ops.append(state_ops.assign(variable, first_output))

    session.run(control_flow_ops.group(pending_ops))
    # Record the key only after the run succeeded so a failure can be retried.
    _keras_graph_key.append(current_key)
def _wait_for_variable_initialization(session):
    """Poll until every not-yet-marked graph variable reports initialized.

    Candidates are the variables of the current ``K`` graph whose
    ``_keras_initialized`` attribute is missing or falsy. Returns immediately
    when there are none. Otherwise the initialization flags of all candidates
    are fetched through ``session.run`` repeatedly, with no delay between
    passes, until a pass reports no uninitialized variable.

    Every candidate gets ``_keras_initialized = True`` on each pass,
    whatever its flag was on that pass.

    Args:
        session: Object whose ``run`` method evaluates the
            ``variables.is_variable_initialized`` checks.
    """
    graph_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
    unmarked = [
        var for var in graph_variables
        if not getattr(var, '_keras_initialized', False)
    ]
    if not unmarked:
        return

    still_waiting = True
    while still_waiting:
        init_flags = session.run(
            [variables.is_variable_initialized(var) for var in unmarked])
        still_waiting = False
        for var, ready in zip(unmarked, init_flags):
            var._keras_initialized = True  # pylint: disable=protected-access
            if not ready:
                # At least one variable is not ready yet: poll again.
                still_waiting = True
def _wait_for_variable_initialization(session):
    """Poll until every not-yet-marked graph variable reports initialized.

    Candidates are the variables of the current ``K`` graph whose
    ``_keras_initialized`` attribute is missing or falsy. Returns immediately
    when there are none. Otherwise the initialization flags of all candidates
    are fetched through ``session.run`` repeatedly, with no delay between
    passes, until a pass reports no uninitialized variable.

    Every candidate gets ``_keras_initialized = True`` on each pass,
    whatever its flag was on that pass.

    Args:
        session: Object whose ``run`` method evaluates the
            ``variables.is_variable_initialized`` checks.
    """
    graph_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
    unmarked = [
        var for var in graph_variables
        if not getattr(var, '_keras_initialized', False)
    ]
    if not unmarked:
        return

    still_waiting = True
    while still_waiting:
        init_flags = session.run(
            [variables.is_variable_initialized(var) for var in unmarked])
        still_waiting = False
        for var, ready in zip(unmarked, init_flags):
            var._keras_initialized = True  # pylint: disable=protected-access
            if not ready:
                # At least one variable is not ready yet: poll again.
                still_waiting = True