def make_const_node(data: Tensor, name: str = None) -> NodeDef:
    """
    Build a ``Const`` graph node that holds ``data`` as its value.

    The node mirrors what ``tf.constant`` would add to the default graph.
    The raw bytes of ``data`` are stored in the tensor's ``tensor_content``,
    together with its shape and datatype. The datatype is repeated in the
    node's ``dtype`` attribute.

    Args:
        data: Numpy-array containing the data, shape, and datatype
        name: Optional node name; ``'Const'`` is used when it is falsy

    Returns:
        Graph node for adding to a TF Graph instance
    """
    type_enum = as_dtype(data.dtype).as_datatype_enum
    # One Dim entry per axis of the array, in axis order.
    shape_proto = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=extent) for extent in data.shape])
    value_proto = TensorProto(dtype=type_enum,
                              tensor_shape=shape_proto,
                              tensor_content=data.tobytes())
    node_attrs = {
        'value': AttrValue(tensor=value_proto),
        'dtype': AttrValue(type=type_enum),
    }
    return NodeDef(op='Const', name=name or 'Const', attr=node_attrs)
Esempio n. 2
0
    def update_colocation_group(self, get_colocation_op):
        """
        Rebind ops colocated with the master variable to the first proxy variable.

        Only ops whose name starts with ``AUTODIST_REPLICA_PREFIX`` are visited.
        The following ops are skipped:

        - ops in the optimizer name scope;
        - ops inside this variable's own name scope;
        - ``VarHandleOp`` ops whose name starts with this variable's op name.

        For every other visited op, each colocation group that
        ``get_colocation_op`` resolves to ``self._this_op`` is replaced by a
        group that points at ``self._proxy_vars[0]``. All other groups are
        kept. The rebuilt list is written back to the op's ``_class`` attribute.

        Args:
            get_colocation_op (Callable): fn that gets the current colocation ops
        """
        for op in self._graph_item.graph.get_operations():
            name = op.name
            skip = (
                not name.startswith(AUTODIST_REPLICA_PREFIX)
                # shared nodes, including nodes in the optimizer
                or name.startswith(self._optimizer_name_scope)
                # nodes within the variable scope of the master var
                or name.startswith(self._this_op.name + '/')
                # the VarHandleOp itself
                or (name.startswith(self._this_op.name)
                    and op.type == 'VarHandleOp')
            )
            if skip:
                continue
            # The proxy group is only built when a group actually needs rebinding.
            rebound_groups = [
                COLOCATION_PREFIX + as_bytes(self._proxy_vars[0].op.name)
                if get_colocation_op(group) == self._this_op
                else group
                for group in op.colocation_groups()
            ]
            op._set_attr(
                "_class",
                pb2_AttrValue(list=pb2_AttrValue.ListValue(s=rebound_groups)))
def make_op_node(op_name: Text, inputs: Inputs, name: Text = None) -> NodeDef:
    """
    Create a TF graph node given the operation, input, and a name.

    The resulting node definition won't include any operation-specific
    attributes. The only attribute set is ``T``, which is always float32.
    It returns a valid node for most operations, though.

    Args:
        op_name: Native TF operation name (e.g. "MatMul")
        inputs: Input node, node name, or list/tuple of input nodes or node
                names. The caller's sequence is never modified.
        name: Node name in the graph, must be unique and defaults to the
              operation name

    Returns:
        TF graph node definition for the given operation, inputs, and name
    """
    # Wrap a single input. Copy sequences so the caller's list is not
    # rewritten in place when nodes are replaced by their names below.
    if isinstance(inputs, (list, tuple)):
        input_items = list(inputs)
    else:
        input_items = [inputs]
    # Inputs exposing a ``name`` attribute (i.e. node objects) are referenced
    # by that name; plain strings are used as-is.
    input_names = [item.name if hasattr(item, 'name') else item
                   for item in input_items]
    # generate node definition
    dtype = dtypes.float32.as_datatype_enum
    node_def = NodeDef(op=op_name,
                       name=name or op_name,
                       attr={'T': AttrValue(type=dtype)})
    node_def.input.extend(input_names)
    return node_def
Esempio n. 4
0
 def _prune_colocation_groups(graph_item):
     for op in graph_item.graph.get_operations():
         # Now prune the graph to have the right colocation constraints
         colocation_groups = [(c, graph_item.get_colocation_op(c))
                              for c in op.colocation_groups()]
         # We don't want any colocation groups that are just this `op`
         colocation_groups = [(c, bind_op)
                              for (c, bind_op) in colocation_groups
                              if bind_op != op]
         if colocation_groups:
             device_to_bind_to = colocation_groups[-1][1].device
             new_colocation_groups = [
                 c for (c, op) in colocation_groups
                 if op.device == device_to_bind_to
             ]
             op._set_device(device_to_bind_to)
             op._set_attr(
                 "_class",
                 pb2_AttrValue(list=pb2_AttrValue.ListValue(
                     s=new_colocation_groups)))
         else:
             try:
                 if op.get_attr("_class"):
                     op._clear_attr("_class")
             except ValueError:
                 pass
Esempio n. 5
0
def update_colocation_group(ops, old_op, new_op):
    """
    For each op in ops, replace the colocation group of old_op with that of new_op.

    An op is updated only when its colocation groups equal old_op's groups
    exactly. If ``colocation_groups()`` of old_op or new_op returns an empty
    result, that op's groups fall back to the single group
    ``COLOCATION_PREFIX + <that op's own name>``.

    Args:
        ops (Iterable[Operation]): The operations to update
        old_op (Operation): The op having the old colocation group
        new_op (Operation): The op having the new colocation group
    """
    # old_op's fallback must be derived from old_op's own name (not new_op's).
    # Otherwise ops implicitly colocated with old_op would never match.
    old_groups = old_op.colocation_groups() or [
        COLOCATION_PREFIX + as_bytes(old_op.name)
    ]
    new_groups = new_op.colocation_groups() or [
        COLOCATION_PREFIX + as_bytes(new_op.name)
    ]
    for op in ops:
        if op.colocation_groups() == old_groups:
            op._set_attr("_class",
                         AttrValue(list=AttrValue.ListValue(s=new_groups)))
            # Sanity check that the attribute write took effect
            # (skipped under `python -O`).
            assert op.colocation_groups() == new_groups
Esempio n. 6
0
def name_attr_list():
    """Return a sample ``NameAttrList`` with a float attr and a bool-list attr."""
    float_attr = AttrValue(f=3.14159)
    bool_list_attr = AttrValue(list=AttrValue.ListValue(b=[True, False, True]))
    return NameAttrList(
        name='test_name_attr_list',
        attr={'float': float_attr, 'list': bool_list_attr},
    )
Esempio n. 7
0
def bool_list():
    """Return an ``AttrValue.ListValue`` holding the booleans ``[True, False]``."""
    flags = [True, False]
    return AttrValue.ListValue(b=flags)
Esempio n. 8
0
def int_list():
    """Return an ``AttrValue.ListValue`` holding the integers 1, 2 and 3."""
    return AttrValue.ListValue(i=list(range(1, 4)))