def make_const_node(data: Tensor, name: str = None) -> NodeDef:
    """
    Build a ``Const`` graph node holding the values of ``data``.

    The node is equivalent to what ``tf.constant`` adds to the default graph.

    Args:
        data: Numpy-array supplying the values, shape, and dtype of the constant
        name: Optional node name; ``'Const'`` is used when it is omitted or empty

    Returns:
        Graph node (NodeDef) ready to be added to a TF Graph instance
    """
    dtype_enum = as_dtype(data.dtype).as_datatype_enum
    # The raw buffer of the array is stored verbatim as the tensor content.
    raw_bytes = data.tobytes()
    shape_proto = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=dim_size) for dim_size in data.shape])
    value_proto = TensorProto(dtype=dtype_enum,
                              tensor_shape=shape_proto,
                              tensor_content=raw_bytes)
    attributes = {
        'value': AttrValue(tensor=value_proto),
        'dtype': AttrValue(type=dtype_enum),
    }
    return NodeDef(op='Const', name=name or 'Const', attr=attributes)
def update_colocation_group(self, get_colocation_op):
    """
    Rebind colocation groups from the master variable to the first proxy variable.

    For every replica op in the graph (name starts with AUTODIST_REPLICA_PREFIX),
    each colocation group that ``get_colocation_op`` resolves to ``self._this_op``
    is replaced by the group of ``self._proxy_vars[0].op``; other groups are kept.
    Skipped ops: non-replica (shared) ops, ops in the optimizer name scope, ops
    under ``self._this_op.name + '/'``, and VarHandleOps whose name starts with
    ``self._this_op.name``. Every op that is not skipped gets its ``_class``
    attribute rewritten, even when none of its groups changed.

    Args:
        get_colocation_op (Callable): fn that maps a colocation group to the op
            it currently binds to
    """
    for op in self._graph_item.graph.get_operations():
        # Do not update shared nodes (including nodes in the optimizer).
        if not op.name.startswith(AUTODIST_REPLICA_PREFIX):
            continue
        if op.name.startswith(self._optimizer_name_scope):
            continue
        # Do not update operations within the variable scope of the master var.
        if op.name.startswith(self._this_op.name + '/'):
            continue
        # Do not update the VarHandleOp itself.
        if op.name.startswith(self._this_op.name) and op.type == 'VarHandleOp':
            continue

        rebound_groups = []
        for group in op.colocation_groups():
            if get_colocation_op(group) == self._this_op:
                # Proxy lookup stays lazy: only evaluated when a group matches.
                group = COLOCATION_PREFIX + as_bytes(self._proxy_vars[0].op.name)
            rebound_groups.append(group)

        class_attr = pb2_AttrValue(list=pb2_AttrValue.ListValue(s=rebound_groups))
        op._set_attr("_class", class_attr)
def make_op_node(op_name: Text, inputs: Inputs, name: Text = None) -> NodeDef:
    """
    Create a TF graph node given the operation, inputs, and a name.

    The resulting node definition only carries a ``T`` attribute set to float32;
    no other operation-specific attributes are added. It is nevertheless a valid
    node for most operations.

    Args:
        op_name: Native TF operation name (e.g. "MatMul")
        inputs: Input node, node name, or list/tuple of input nodes or node names.
            The caller's list is never modified.
        name: Node name in the graph; must be unique and defaults to ``op_name``

    Returns:
        TF graph node definition for the given operation, inputs, and name
    """
    # Normalise to a fresh list. The previous version aliased the caller's
    # list and overwrote its items with node names in place.
    if isinstance(inputs, (list, tuple)):
        input_list = list(inputs)
    else:
        input_list = [inputs]
    # Anything exposing a ``name`` attribute (e.g. a node) is referenced by name;
    # plain strings are passed through unchanged.
    input_names = [item.name if hasattr(item, 'name') else item
                   for item in input_list]
    # generate node definition
    dtype = dtypes.float32.as_datatype_enum
    node_def = NodeDef(op=op_name, name=name or op_name,
                       attr={'T': AttrValue(type=dtype)})
    node_def.input.extend(input_names)
    return node_def
def _prune_colocation_groups(graph_item):
    """
    Make each op's colocation constraints agree with a single device.

    For every op in ``graph_item.graph``:

    - Each colocation group is resolved with ``graph_item.get_colocation_op``.
    - Groups that resolve to the op itself are dropped.
    - If any groups remain, the op is moved to the device of the op bound by
      the last remaining group. Its ``_class`` attribute then keeps only the
      groups whose bound op is on that device.
    - If no groups remain, a truthy ``_class`` attribute is cleared.
    """
    for op in graph_item.graph.get_operations():
        # Pair each colocation group with the op it binds to, ignoring
        # groups that are just this `op`.
        bindings = []
        for group in op.colocation_groups():
            target = graph_item.get_colocation_op(group)
            if target != op:
                bindings.append((group, target))

        if not bindings:
            try:
                if op.get_attr("_class"):
                    op._clear_attr("_class")
            except ValueError:
                # NOTE(review): presumably raised by get_attr when the op has
                # no "_class" attr; treated as nothing to clear.
                pass
            continue

        # The last remaining group decides the device for this op.
        chosen_device = bindings[-1][1].device
        kept_groups = [group for group, target in bindings
                       if target.device == chosen_device]
        op._set_device(chosen_device)
        op._set_attr(
            "_class",
            pb2_AttrValue(list=pb2_AttrValue.ListValue(s=kept_groups)))
def update_colocation_group(ops, old_op, new_op):
    """
    For each op in ops, replace the colocation group of old_op with that of new_op.

    An op is updated only when its colocation groups are exactly equal to
    those of ``old_op``. Its ``_class`` attribute is then set to the groups
    of ``new_op``.

    Args:
        ops (Iterable[Operation]): The operations to update
        old_op (Operation): The op having the old colocation group
        new_op (Operation): The op having the new colocation group

    Raises:
        AssertionError: if an updated op does not report ``new_op``'s groups
            after the attribute write (sanity check).
    """
    # Each fallback must name its own op. Using new_op's name for old_groups
    # would match ops colocated with new_op instead of old_op.
    old_groups = old_op.colocation_groups() or [
        COLOCATION_PREFIX + as_bytes(old_op.name)
    ]
    new_groups = new_op.colocation_groups() or [
        COLOCATION_PREFIX + as_bytes(new_op.name)
    ]
    for op in ops:
        if op.colocation_groups() == old_groups:
            op._set_attr("_class", AttrValue(list=AttrValue.ListValue(s=new_groups)))
            assert op.colocation_groups() == new_groups
def name_attr_list():
    """Return a sample ``NameAttrList`` with a float attr and a bool-list attr."""
    float_attr = AttrValue(f=3.14159)
    bools_attr = AttrValue(list=AttrValue.ListValue(b=[True, False, True]))
    return NameAttrList(name='test_name_attr_list',
                        attr={'float': float_attr, 'list': bools_attr})
def bool_list():
    """Return an ``AttrValue.ListValue`` holding the booleans ``[True, False]``."""
    flags = [True, False]
    return AttrValue.ListValue(b=flags)
def int_list():
    """Return an ``AttrValue.ListValue`` holding the integers ``[1, 2, 3]``."""
    values = list(range(1, 4))
    return AttrValue.ListValue(i=values)