def _get_new_args(self, args, kwargs):
    # TODO: currently this follows the Keras convention of treating the first dimension as the batch dim.
    # However, we should use tf.function to handle the full range of input signatures.
    _warn_msg = 'AutoDist treats the first dimension of autodist.function input with shape {} as batch dimension'
    # Insert placeholders in place of ndarrays
    args_with_ph = []
    kwargs_with_ph = {}
    for i, arg in enumerate(args):
        if isinstance(arg, np.ndarray):
            logging.warning(_warn_msg.format(arg.shape))
            ph = array_ops.placeholder(dtype=arg.dtype, shape=(None, *arg.shape[1:]))
            args_with_ph.append(ph)
            self._ph_feed_index[ph] = i
        else:
            args_with_ph.append(arg)
    for k, v in kwargs.items():
        if isinstance(v, np.ndarray):
            logging.warning(_warn_msg.format(v.shape))
            ph = array_ops.placeholder(dtype=v.dtype, shape=(None, *v.shape[1:]))
            kwargs_with_ph[k] = ph
            self._ph_feed_index[ph] = k
        else:
            kwargs_with_ph[k] = v
    return tuple(args_with_ph), kwargs_with_ph
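# A minimal, self-contained sketch of the placeholder-substitution pattern above,
# written against the public tf.compat.v1 API (the real code uses TensorFlow's
# internal array_ops). The function name and demo values are illustrative only.
import numpy as np
import tensorflow as tf

def _swap_ndarrays_for_placeholders(args):
    """Replace each ndarray arg with a placeholder whose batch dim is None."""
    feed_index = {}
    new_args = []
    for i, arg in enumerate(args):
        if isinstance(arg, np.ndarray):
            ph = tf.compat.v1.placeholder(dtype=arg.dtype, shape=(None, *arg.shape[1:]))
            feed_index[ph] = i  # remember which positional arg feeds this placeholder
            new_args.append(ph)
        else:
            new_args.append(arg)
    return tuple(new_args), feed_index

# Placeholders only exist in graph mode, so the demo runs inside an explicit graph.
with tf.Graph().as_default():
    demo_args, demo_index = _swap_ndarrays_for_placeholders((np.zeros((8, 4), np.float32), 3))
    assert demo_args[0].shape.as_list() == [None, 4] and demo_args[1] == 3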
def serialize(self, path):
    """Serialize this graph_item to a proto string and write it to the given file path."""
    item_def = graphitem_pb2.GraphItem()
    # GraphDef
    item_def.graph_def.Pack(self.graph.as_graph_def())
    # Grad Target Pairs
    for k, v in self._grad_target_pairs.items():
        if isinstance(k, tuple):
            k = ';'.join(k)
        item_def.grad_target_pairs[k] = v
    # Info
    def _pack_into(v, repeated_any):
        a = Any()
        a.Pack(v)
        repeated_any.append(a)
    for v in self.info.variables:
        _pack_into(v, item_def.info.variables)
    for v in self.info.savers:
        _pack_into(v, item_def.info.savers)
    item_def.info.table_initializers.extend(self.info.table_initializers)
    logging.warning(
        'GraphItem currently does not serialize optimizer info; '
        'optimizer info is only used temporarily by the partitioner.')
    # Serialization
    with open(path, "wb+") as f:
        f.write(item_def.SerializeToString())
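# A short sketch of the Any-packing pattern serialize() relies on: wrapping a
# concrete proto message in google.protobuf.Any records its type URL, so
# heterogeneous messages can share one repeated field and be unpacked later.
# Only public protobuf/TensorFlow proto APIs are used here.
from google.protobuf.any_pb2 import Any
from tensorflow.core.framework.graph_pb2 import GraphDef

wrapped = Any()
wrapped.Pack(GraphDef())          # serializes the message and stores its type URL
restored = GraphDef()
assert wrapped.Unpack(restored)   # Unpack succeeds only if the type URL matches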
def is_built(self):
    """
    Whether the distributed graph has been built for the most recent original graph.

    Returns:
        bool: True if the distributed graph was built by AutoDist.
    """
    if self._built:
        if ENV.AUTODIST_IS_TESTING.val and \
                self._original_graph_item.graph.as_graph_def() != self._built:
            msg = 'Graph was modified after the distributed session was created.'
            logging.warning(msg)
            raise RuntimeWarning(msg)
        return True
    return False
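# Illustration of the staleness check: is_built() compares a GraphDef snapshot
# taken at build time against the current graph, so any op added afterwards makes
# the two protos unequal. Self-contained graph-mode demo; names are illustrative.
import tensorflow as tf

_g = tf.Graph()
with _g.as_default():
    tf.constant(1.0)
_snapshot = _g.as_graph_def()      # what the check above compares against
with _g.as_default():
    tf.constant(2.0)               # graph modified after the snapshot
assert _g.as_graph_def() != _snapshot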
def _update_all_consumers(self):
    """Update all proxy variable consumers."""
    for consumer_op in self._consumer_to_read_var_op:
        if consumer_op in self._proxy_var_init_ops:
            continue
        if consumer_op.name.startswith(AUTODIST_REPLICA_PREFIX):
            if len(self._proxy_vars) > 1:
                raise ValueError('Currently at most one proxy is created per variable.')
            self._update_consumer(self._proxy_vars[0], consumer_op)
        else:
            # TODO: Attention: ReadVarOp consumers include the "save".
            logging.warning(
                "Consumer %s of the value of variable %s is a shared node; "
                "it is not redirected to the proxy variable",
                consumer_op.name, self._this_op.name)
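# Background for the rewiring above: tensor.consumers() lists every op that takes
# the tensor as an input, which is how read-variable consumers are discovered
# before being redirected to the proxy. A minimal graph-mode demo:
import tensorflow as tf

with tf.Graph().as_default():
    v = tf.compat.v1.get_variable('v_demo', shape=[2], use_resource=True)
    read = v.read_value()
    out = tf.identity(read)
    assert out.op in read.consumers()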
def patch_var_reading():
    """Patch ResourceVariable reads; this only works with tf.gradients, not tape.gradient."""
    def value(self):
        """A cached operation which reads the value of this variable."""
        if self._cached_value is not None:
            return self._cached_value
        with ops.colocate_with(None, ignore_existing=True):
            with ops.device(self._handle.device):
                # return self._read_variable_op()  # original line
                return self._graph_element
    setattr(ResourceVariable, 'value', value)
    logging.warning(
        'Resource variable is patched to behave as a ref (on reading only) '
        'to avoid multiple recv_tensor ops.')
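# The generic shape of the monkey-patch above: replace a method on a class at
# runtime via setattr, so every existing and future instance picks up the new
# behavior. Class and method names below are illustrative only.
class _Reader:
    def value(self):
        return 'fresh read op'

def _cached_value(self):
    return 'cached graph element'

setattr(_Reader, 'value', _cached_value)
assert _Reader().value() == 'cached graph element'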
def _collect_sparse_gradients(self, graph_item, var_op_name):
    """Append collective ops after the gradient is calculated."""
    if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
        raise NotImplementedError(
            'Currently the collective NCCL AllGather is not supported in the TensorFlow release. '
            'Please choose another strategy.')
    conf = {}
    if self._spec:
        conf = {'communication_hint': self._spec}
    if self._compressor_type:
        logging.warning('AllGather currently does not support AutoDist compressors; '
                        'the compressor setting is ignored.')
    if self.num_replicas * self.num_workers <= 1:
        raise ValueError('CollectiveOps requires a collective group size > 1.')
    for i in range(self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        indices_c_ops = grad.indices.consumers()
        indices_cc_ops = get_control_consumers(grad.indices.op)
        values_c_ops = grad.values.consumers()
        values_cc_ops = get_control_consumers(grad.values.op)
        with ops.name_scope(replica_prefix(i)):
            with ops.colocate_with(grad.indices.op):
                new_indices = collective_ops.all_gather(
                    grad.indices,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-indices'),
                    **conf)
            with ops.colocate_with(grad.values.op):
                new_values = collective_ops.all_gather(
                    grad.values,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-values'),
                    **conf)
        update_consumers(indices_c_ops, grad.indices, new_indices)
        update_control_consumers(indices_cc_ops, grad.indices.op, new_indices.op)
        update_consumers(values_c_ops, grad.values, new_values)
        update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
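# Why indices and values are gathered separately: the gradient of a sparse
# lookup is an IndexedSlices whose `indices` and `values` are distinct tensors,
# each with its own consumers in the graph. A minimal graph-mode demonstration:
import tensorflow as tf

with tf.Graph().as_default():
    emb = tf.compat.v1.get_variable('emb_demo', shape=[10, 4], use_resource=True)
    loss = tf.reduce_sum(tf.nn.embedding_lookup(emb, tf.constant([1, 3])))
    grad = tf.compat.v1.gradients(loss, [emb])[0]
    assert isinstance(grad, tf.IndexedSlices)
    assert grad.indices is not grad.values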
def __init__(self, config: synchronizers_pb2.AllReduceSynchronizer):
    self._spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Name(config.spec)
    # The spec is only effective on TF 1.x >= 1.15 or TF 2.x >= 2.1; the original
    # condition (v < 1.15 or v < 2.1) wrongly triggered for 1.15 <= v < 2.0.
    if autodist.float_major_minor_tf_version < 1.15 \
            or 2.0 <= autodist.float_major_minor_tf_version < 2.1:
        logging.warning(
            'Collective synchronizer spec "{}" (a.k.a. communication_hint) has no effect '
            'before tensorflow-gpu 1.x >= 1.15 or 2.x >= 2.1, and may currently cause errors.'
            .format(self._spec))
        self._spec = None
    self._compressor_type = synchronizers_pb2.AllReduceSynchronizer.Compressor.Name(
        config.compressor)
    # Collective ops within the same group will be merged by the scoped optimizer.
    # Normally the group index should be smaller than the number of variables in the graph;
    # this kernel assumes the strategy validates that the group assignments are legitimate.
    self._group = config.group
    super().__init__()
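# The version gate above, stated directly: the spec only takes effect on
# TF 1.x >= 1.15 or TF 2.x >= 2.1, so it is dropped (with a warning) elsewhere.
# The helper name is hypothetical; it mirrors the corrected condition.
def _spec_supported(major_minor: float) -> bool:
    return 1.15 <= major_minor < 2.0 or major_minor >= 2.1

assert not _spec_supported(1.14)
assert _spec_supported(1.15)
assert not _spec_supported(2.0)
assert _spec_supported(2.1)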
def _check_partition_list(partition_list):
    if not partition_list:
        logging.warning('Partition list is empty.')
        return False
    all_one = True
    active_axis = 0
    for p in partition_list:
        if p == 0:
            return False
        if p > 1:
            all_one = False
            active_axis += 1
    if all_one:
        logging.warning('Partition list is trivial -- num_split is 1 on every axis.')
        return False
    if active_axis > 1:
        logging.warning('Currently AutoDist only supports partitioning along one axis.')
        return False
    return True
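# How the validator behaves on representative inputs (assuming it is callable as
# a module-level helper): exactly one axis may have num_split > 1, no axis may
# be 0, and an all-ones list is rejected as trivial.
assert _check_partition_list([1, 4, 1])      # one active axis: valid
assert not _check_partition_list([1, 1, 1])  # trivial: num_split is 1 everywhere
assert not _check_partition_list([2, 2, 1])  # two active axes: unsupported
assert not _check_partition_list([0, 4])     # a zero split is invalid
assert not _check_partition_list([])         # empty list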