def _map_subgraph(sources, sinks):
    """Captures subgraph between sources and sinks.

    Walks a Graph backwards from `sinks` to `sources` and returns any extra
    sources encountered in the subgraph that were not specified in `sources`.

    Arguments:
      sources: An iterable of Tensors where subgraph extraction should stop.
        It is read exactly once, so one-shot iterators are accepted.
      sinks: An iterable of Operations where the subgraph terminates.

    Returns:
      An `ObjectIdentitySet` with the output tensors of every `Placeholder` op
      upon which `sinks` depend, excluding the tensors in `sources`.
    """
    # Materialize `sources` once: it is needed both to stop the walk and to
    # filter the result, and a one-shot iterator would be empty on re-read.
    stop_at_tensors = object_identity.ObjectIdentitySet(sources)
    ops_to_visit = object_identity.ObjectIdentitySet(sinks)
    visited_ops = object_identity.ObjectIdentitySet()
    potential_extra_sources = object_identity.ObjectIdentitySet()
    while ops_to_visit:
        op = ops_to_visit.pop()
        visited_ops.add(op)

        if op.type == 'Placeholder':
            potential_extra_sources.update(op.outputs)

        # Do not walk past a tensor that was explicitly listed as a source.
        input_ops = [t.op for t in op.inputs if t not in stop_at_tensors]
        for input_op in itertools.chain(input_ops, op.control_inputs):
            if input_op not in visited_ops:
                ops_to_visit.add(input_op)

    return potential_extra_sources.difference(stop_at_tensors)
示例#2
0
def retrieve_sources(sinks, ignore_control_dependencies=False):
    """Collects the placeholder outputs that `sinks` depend on.

    Walks a Graph backwards from `sinks` and returns any sources encountered in
    the subgraph. This util is refactored from `_map_subgraph` in
    tensorflow/.../ops/op_selector.py; unlike that function it takes no stop
    tensors, so every data input of every visited op is followed.

    Arguments:
      sinks: An iterable of Operations where the subgraph terminates.
      ignore_control_dependencies: (Optional) If `True`, ignore any
        `control_inputs` for all ops while walking the graph.

    Returns:
      An `ObjectIdentitySet` with the output tensors of every `Placeholder` op
      upon which `sinks` depend. This could also contain placeholders
      representing `captures` in the graph.
    """
    ops_to_visit = object_identity.ObjectIdentitySet(sinks)
    visited_ops = object_identity.ObjectIdentitySet()
    potential_extra_sources = object_identity.ObjectIdentitySet()
    while ops_to_visit:
        op = ops_to_visit.pop()
        visited_ops.add(op)

        if op.type == 'Placeholder':
            potential_extra_sources.update(op.outputs)

        # There are no stop tensors here, so every data input is followed.
        input_ops = [t.op for t in op.inputs]
        if not ignore_control_dependencies:
            input_ops = itertools.chain(input_ops, op.control_inputs)
        for input_op in input_ops:
            if input_op not in visited_ops:
                ops_to_visit.add(input_op)

    return potential_extra_sources
示例#3
0
    def _call_func(self, args, kwargs):
        """Runs the template function, checking for unexpected new variables.

        On the first call `self._func` runs inside
        `trackable_util.capture_dependencies(template=self)`, and
        `self._variables_created` is set to True only after it returns. On
        later calls `self._func` runs directly. The variable counts in
        `self._template_store` are then compared before and after the call:
        - a change in the number of trainable variables raises `ValueError`;
        - a change in the total number of variables is only logged at INFO.

        Any exception raised inside the `try`, including that `ValueError`,
        is re-raised. Before re-raising, its first argument is extended with
        an "originally defined at:" trace. The trace is built by
        `_skip_common_stack_elements` from `self._stacktrace` and the current
        stack.

        Args:
          args: Positional arguments forwarded to `self._func`.
          kwargs: Keyword arguments forwarded to `self._func`.

        Returns:
          The value returned by `self._func`.
        """
        try:
            vars_at_start = self._template_store.variables()
            trainable_at_start = self._template_store.trainable_variables()
            if self._variables_created:
                result = self._func(*args, **kwargs)
            else:
                # The first time we run, restore variables if necessary (via
                # Trackable).
                with trackable_util.capture_dependencies(template=self):
                    result = self._func(*args, **kwargs)

            if self._variables_created:
                # Variables were previously created, implying this is not the first
                # time the template has been called. Check to make sure that no new
                # trainable variables were created this time around.
                trainable_variables = self._template_store.trainable_variables(
                )
                # If a variable that we intend to train is created as a side effect
                # of creating a template, then that is almost certainly an error.
                # Only the counts are compared, not the identities.
                if len(trainable_at_start) != len(trainable_variables):
                    raise ValueError(
                        "Trainable variable created when calling a template "
                        "after the first time, perhaps you used tf.Variable "
                        "when you meant tf.get_variable: %s" % list(
                            object_identity.ObjectIdentitySet(
                                trainable_variables) - object_identity.
                            ObjectIdentitySet(trainable_at_start)))

                # Non-trainable tracking variables are a legitimate reason why a new
                # variable would be created, but it is a relatively advanced use-case,
                # so log it.
                variables = self._template_store.variables()
                if len(vars_at_start) != len(variables):
                    logging.info(
                        "New variables created when calling a template after "
                        "the first time, perhaps you used tf.Variable when you "
                        "meant tf.get_variable: %s",
                        list(
                            object_identity.ObjectIdentitySet(variables) -
                            object_identity.ObjectIdentitySet(vars_at_start)))
            else:
                # Only mark variables as created once the first call succeeded.
                self._variables_created = True
            return result
        except Exception as exc:
            # Reraise the exception, but append the original definition to the
            # trace.
            args = exc.args
            if not args:
                arg0 = ""
            else:
                arg0 = args[0]
            trace = "".join(
                _skip_common_stack_elements(self._stacktrace,
                                            traceback.format_stack()))
            arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
            new_args = [arg0]
            new_args.extend(args[1:])
            # Mutate the caught exception in place; the bare `raise` keeps the
            # original traceback.
            exc.args = tuple(new_args)
            raise
def _get_resource_inputs(op):
    """Returns an iterable of resources touched by this `op`.

    Resource-typed inputs are first split into reads and writes using
    `utils.op_writes_to_resource`. Every registered resolver is then re-run
    until one full pass reports no update. After each update, handles that
    are writes are removed from the reads.

    Returns:
      A list of `(tensor, ResourceType)` pairs: all read-only handles first,
      then all read-write handles.
    """
    reads = object_identity.ObjectIdentitySet()
    writes = object_identity.ObjectIdentitySet()
    resource_tensors = (
        t for t in op.inputs if t.dtype == dtypes_module.resource)
    for tensor in resource_tensors:
        bucket = writes if utils.op_writes_to_resource(tensor, op) else reads
        bucket.add(tensor)

    changed = True
    while changed:
        changed = False
        for key in _acd_resource_resolvers_registry.list():
            resolver = _acd_resource_resolvers_registry.lookup(key)
            # Resolvers return true when they updated `reads`/`writes`.
            # TODO(srbs): An alternate would be to just compare the old and new
            # set but that may not be as fast.
            if resolver(op, reads, writes):
                # Conservatively remove any resources from `reads` that are
                # also writes.
                reads = reads.difference(writes)
                changed = True

    # Note: A resource handle that is not written to is treated as read-only.
    # We don't have a special way of denoting an unused resource.
    read_only = [(t, ResourceType.READ_ONLY) for t in reads]
    read_write = [(t, ResourceType.READ_WRITE) for t in writes]
    return read_only + read_write
def get_reachable_from_inputs(inputs, targets=None):
    """Returns the set of tensors/ops reachable from `inputs`.

    Stops if all targets have been found (target is optional).

    Only valid in Symbolic mode, not Eager mode.

    Args:
      inputs: List of tensors.
      targets: List of tensors.

    Returns:
      A set of tensors reachable from the inputs (includes the inputs
      themselves).

    Raises:
      TypeError: If a visited object is not an Operation, Variable, Tensor,
        user-convertible tensor type, or string.
    """
    import collections  # deque gives O(1) appends on the left end.

    inputs = nest.flatten(inputs, expand_composites=True)
    reachable = object_identity.ObjectIdentitySet(inputs)
    if targets:
        remaining_targets = object_identity.ObjectIdentitySet(
            nest.flatten(targets))
    # New items go on the left and are taken from the right: breadth-first,
    # the same visiting order as the previous list.insert(0, ...) version.
    queue = collections.deque(inputs)

    while queue:
        x = queue.pop()
        if isinstance(x, tuple(_user_convertible_tensor_types)):
            # Can't find consumers of user-specific types.
            continue

        if isinstance(x, ops.Operation):
            outputs = x.outputs[:] or []
            outputs += x._control_outputs  # pylint: disable=protected-access
        elif isinstance(x, variables.Variable):
            try:
                outputs = [x.op]
            except AttributeError:
                # Variables can be created in an Eager context.
                outputs = []
        elif tensor_util.is_tensor(x):
            outputs = x.consumers()
        elif isinstance(x, str):
            # Strings are tolerated but have no consumers. Without this
            # assignment `outputs` would be unbound or left over from the
            # previously visited node.
            outputs = []
        else:
            raise TypeError(
                'Expected Operation, Variable, or Tensor, got ' + str(x))

        for y in outputs:
            if y not in reachable:
                reachable.add(y)
                if targets:
                    remaining_targets.discard(y)
                queue.appendleft(y)

        if targets and not remaining_targets:
            return reachable

    return reachable
示例#6
0
    def testDifference(self):
        """`difference` keeps only members absent from the other set."""
        class Element(object):
            pass

        only_left, shared, only_right = Element(), Element(), Element()
        left = object_identity.ObjectIdentitySet([only_left, shared])
        right = object_identity.ObjectIdentitySet([shared, only_right])
        result = left.difference(right)
        self.assertIn(only_left, result)
        self.assertNotIn(shared, result)
        self.assertNotIn(only_right, result)
示例#7
0
def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction whose inputs are the inputs of `func` that were not
    converted to constants.
  """
  # Inputs frozen into constants are dropped from the new function's inputs.
  captures = func.graph.internal_captures
  frozen = object_identity.ObjectIdentitySet(
      [captures[i] for i in converted_input_indices])
  remaining_inputs = [t for t in func.inputs if t not in frozen]
  original_by_name = {t.name: t for t in remaining_inputs}

  new_func = wrap_function.function_from_graph_def(
      output_graph_def,
      [t.name for t in remaining_inputs],
      [t.name for t in func.outputs])

  # Scalar shapes are lost when wrapping the function, so copy each input's
  # shape back from the original tensor with the same name.
  for new_input in new_func.inputs:
    new_input.set_shape(original_by_name[new_input.name].shape)
  return new_func
 def testDiscard(self):
     """`discard` removes exactly the given member."""
     removed = object()
     kept = object()
     members = object_identity.ObjectIdentitySet([removed, kept])
     members.discard(removed)
     self.assertIn(kept, members)
     self.assertNotIn(removed, members)
示例#9
0
    def test_model(self):
        """Fits the FBA matting model once and checks variable tracking."""
        model = build_fba_matting()
        model.compile(
            optimizer='sgd',
            loss=['mse', None, None, None],
            run_eagerly=test_utils.should_run_eagerly())

        batch_hw = (2, 240, 240)
        inputs = [
            np.random.random(batch_hw + (channels,)).astype(np.uint8)
            for channels in (3, 2, 6)
        ]
        targets = [
            np.random.random(batch_hw + (channels,)).astype(np.float32)
            for channels in (7, 1, 3, 3)
        ]
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # Building the config must not fail.
        model.get_config()

        # Every model variable must appear among the trackable objects.
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
示例#10
0
def extract_outputs_from_subclassing_model(model, output_dict, input_names,
                                           output_names, input_sigature):
    """Traces a subclassed Keras model and imports its frozen graph.

    Note: `output_dict`, `input_names` and `output_names` are mutated in place.

    Args:
      model: The subclassed Keras model to trace.
      output_dict: Dict updated with the result of `build_layer_outputs`.
      input_names: List extended with the names of the concrete function's
        inputs that were not converted to constants.
      output_names: List extended with the names of all outputs of the op that
        produces each structured output.
      input_sigature: Input signature passed to `trace_model_call`. The
        misspelled name is kept for caller compatibility.

    Returns:
      A `tf.Graph` into which the converted GraphDef was imported.
    """
    from tensorflow.python.keras.saving import saving_utils as _saving_utils
    from tensorflow.python.util import object_identity
    from ._graph_cvt import convert_variables_to_constants_v2 as _convert_to_constants

    function = _saving_utils.trace_model_call(model, input_sigature)
    concrete_func = function.get_concrete_function()
    # Only the values are needed; the structured-output keys are unused.
    for output in concrete_func.structured_outputs.values():
        output_names.extend([ts_.name for ts_ in output.op.outputs])
    output_dict.update(
        build_layer_outputs(model, concrete_func.graph, concrete_func.outputs))
    graph_def, converted_input_indices = _convert_to_constants(
        concrete_func, lower_control_flow=True)
    # Inputs frozen into constants are not reported as graph inputs.
    input_tensors = concrete_func.graph.internal_captures
    converted_inputs = object_identity.ObjectIdentitySet(
        [input_tensors[index] for index in converted_input_indices])
    input_names.extend([
        tensor.name for tensor in concrete_func.inputs
        if tensor not in converted_inputs
    ])

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')

    return tf_graph
示例#11
0
def validate_and_slice_inputs(names_to_saveables):
    """Returns the variables and names that will be used for a Saver.

    Args:
      names_to_saveables: A dict (k, v) where k is the name of an operation
        and v is an operation to save or a BaseSaverBuilder.Saver. A
        non-dict value is first converted with `op_list_to_dict`.

    Returns:
      A list of SaveableObjects, produced in order of name.

    Raises:
      TypeError: If any of the keys are not strings or any of the
        values are not one of Tensor or Variable or a trackable operation.
      ValueError: If the same operation is given in more than one value
        (this also applies to slices of SlicedVariables).
    """
    if not isinstance(names_to_saveables, dict):
        names_to_saveables = op_list_to_dict(names_to_saveables)

    saveables = []
    seen_ops = object_identity.ObjectIdentitySet()
    # Sort by name only: the ops themselves are not meant to be compared.
    by_name = sorted(names_to_saveables.items(), key=lambda item: item[0])
    for name, op in by_name:
        for saveable in saveable_objects_for_op(op, name):
            _add_saveable(saveables, seen_ops, saveable)
    return saveables
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource input counts as read-only when every op consuming it lists it
  among that op's read-only resource inputs, as reported by
  `_get_read_only_resource_input_indices_op`.
  """
  # Cache: Operation -> ObjectIdentitySet of that op's read-only resource
  # input tensors, so each consumer op is analyzed at most once.
  read_only_inputs_by_op = {}

  def _reads_only(op, tensor):
    """Whether `op` uses `tensor` only as a read-only resource input."""
    if op not in read_only_inputs_by_op:
      indices = _get_read_only_resource_input_indices_op(op)
      read_only_inputs_by_op[op] = object_identity.ObjectIdentitySet(
          [op.inputs[i] for i in indices])
    return tensor in read_only_inputs_by_op[op]

  result = []
  for input_index, t in enumerate(func_graph.inputs):
    if t.dtype != dtypes.resource:
      continue
    # all() stops at the first consumer that may write the resource.
    if all(_reads_only(op, t) for op in t.consumers()):
      result.append(input_index)
  return result
示例#13
0
    def test_model(self):
        """Fits a small DeepLabV3+ once and checks variable tracking."""
        num_classes = 5
        model = build_deeplab_v3_plus(
            classes=num_classes, bone_arch='resnet_50', bone_init='imagenet',
            bone_train=False, aspp_filters=8, aspp_stride=16, low_filters=16,
            decoder_filters=4)
        model.compile(
            optimizer='sgd',
            loss='sparse_categorical_crossentropy',
            run_eagerly=test_utils.should_run_eagerly())

        images = np.random.random((2, 224, 224, 3)).astype(np.uint8)
        labels = np.random.randint(0, num_classes, (2, 224, 224))
        model.fit(images, labels, epochs=1, batch_size=10)

        # Building the config must not fail.
        model.get_config()

        # Every model variable must appear among the trackable objects.
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
示例#14
0
def _Inputs(op, xs):
    """Returns the inputs of op, crossing closure boundaries where necessary.

    Args:
      op: Operation
      xs: list of Tensors we are differentiating w.r.t.

    Returns:
      A list of tensors. The tensors may be from multiple Graph/FuncGraphs if
      op is in a FuncGraph and has captured inputs. Outside a function graph,
      `op.inputs` is returned unchanged.
    """
    differentiated = object_identity.ObjectIdentitySet(xs)
    if not _IsFunction(op.graph):  # pylint: disable=protected-access
        return op.inputs

    # An input we differentiate w.r.t. must stay visible to the algorithm
    # even if it is a function input for a captured value. All other inputs
    # are traversed through the closure as if the captured value were the
    # direct input to op.
    return [
        t if t in differentiated else _MaybeCaptured(t) for t in op.inputs
    ]
示例#15
0
 def __init__(self,
              record_initial_resource_uses=False,
              record_uses_of_resource_ids=None):
   """Initializes empty tracking state and stores the recording options.

   Args:
     record_initial_resource_uses: Stored unchanged on the instance.
     record_uses_of_resource_ids: Stored unchanged on the instance.
   """
   self.record_initial_resource_uses = record_initial_resource_uses
   self.record_uses_of_resource_ids = record_uses_of_resource_ids
   # Plain set for ops; identity-based set for tensors.
   self.ops_which_must_run = set()
   self._returned_tensors = object_identity.ObjectIdentitySet()
示例#16
0
    def _add_checkpoint_values_check(self, trackable_objects,
                                     object_graph_proto):
        """Determines which objects have checkpoint values and saves to the proto.

        A node has checkpoint values if it has attributes, slot variables or
        a registered saver itself, or if any node reachable through its
        children does. The result is written to each node's
        `has_checkpoint_values.value`.

        Args:
          trackable_objects: A list of all trackable objects, positionally
            aligned with `object_graph_proto.nodes`.
          object_graph_proto: A `TrackableObjectGraph` proto.
        """
        # Reverse edges: child trackable -> set of trackables listing it as a
        # child (its "parents").
        parents = object_identity.ObjectIdentityDictionary()
        checkpointed = object_identity.ObjectIdentitySet()

        for trackable, node in zip(trackable_objects,
                                   object_graph_proto.nodes):
            if (node.attributes or node.slot_variables
                    or node.HasField("registered_saver")):
                checkpointed.add(trackable)
            for child_ref in node.children:
                child = trackable_objects[child_ref.node_id]
                if child not in parents:
                    parents[child] = object_identity.ObjectIdentitySet()
                parents[child].add(trackable)

        # Propagate upward: every ancestor of a checkpointed node is marked.
        # Popping from `parents` ensures each node's parents are handled once.
        pending = object_identity.ObjectIdentitySet(checkpointed)
        while pending:
            current = pending.pop()
            if current not in parents:
                # Some trackables may not have parents (e.g. slot variables).
                continue
            direct_parents = parents.pop(current)
            checkpointed.update(direct_parents)
            pending.update(p for p in direct_parents if p in parents)

        for node_id, trackable in enumerate(trackable_objects):
            node = object_graph_proto.nodes[node_id]
            node.has_checkpoint_values.value = trackable in checkpointed
示例#17
0
 def _get_feeds(self, unfed_input_keys: Iterable[str]) -> Iterable[tf.Tensor]:
   """Returns set of tensors that will be fed.

   This is every input of `self._func_graph` except the component tensors of
   the structured inputs named in `unfed_input_keys`.
   """
   feeds = object_identity.ObjectIdentitySet(self._func_graph.inputs)
   for key in unfed_input_keys:
     components = _get_component_tensors(self._structured_inputs[key])
     feeds = feeds.difference(components)
   return feeds
示例#18
0
def test_weights_equals_deduplicated_parameter_dict(model):
    """
    Checks GPflux's `model.trainable_weights` elements equals deduplicated
    GPflow's `gpflow.utilities.parameter_dict(model)`.
    """
    # Parameters of type ResourceVariable are skipped: they were added to the
    # model by the `add_metric` call in the layer.
    unique_variables = object_identity.ObjectIdentitySet(
        parameter.unconstrained_variable
        for parameter in parameter_dict(model).values()
        if not isinstance(parameter, ResourceVariable))

    trainable = model.trainable_weights
    assert len(trainable) == len(unique_variables)
    assert object_identity.ObjectIdentitySet(trainable) == unique_variables
示例#19
0
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
    """Walk a Graph and capture the subgraph between init_tensor and sources.

    Note: This function mutates visited_ops and op_outputs.

    Arguments:
      init_tensor: A Tensor or Operation where the subgraph terminates. The
        walk starts from `_as_operation(init_tensor)`.
      sources: A set of Tensors where subgraph extraction should stop.
      disallowed_placeholders: An optional set of ops which may not appear in
        the lifted graph. When None, any Placeholder op reached is disallowed
        unless `add_sources` is True.
      visited_ops: A set of operations which were visited in a prior pass.
        Ops already in it are skipped; every newly visited op is added.
      op_outputs: A defaultdict containing the outputs of an op which are to
        be copied into the new subgraph. For each visited op and each `inp`
        in `graph_inputs(op)`, `op` is added to `op_outputs[inp]`.
      add_sources: A boolean indicating whether placeholders which are not in
        sources should be allowed. Only consulted when
        `disallowed_placeholders` is None.

    Returns:
      An ObjectIdentitySet with the output tensors of every Placeholder op
      visited during this call.

    Raises:
      UnliftableError: if a visited op is in `disallowed_placeholders`, or,
        when `disallowed_placeholders` is None, if a Placeholder op is
        visited and `add_sources` is False.
    """
    ops_to_visit = [_as_operation(init_tensor)]
    extra_sources = object_identity.ObjectIdentitySet()
    while ops_to_visit:
        op = ops_to_visit.pop()
        if op in visited_ops:
            continue
        visited_ops.add(op)

        should_raise = False
        if disallowed_placeholders is not None and op in disallowed_placeholders:
            should_raise = True
        elif op.type == "Placeholder":
            if disallowed_placeholders is None and not add_sources:
                should_raise = True
            extra_sources.update(op.outputs)

        if should_raise:
            raise UnliftableError(
                "Unable to lift tensor %s because it depends transitively on "
                "placeholder %s via at least one path, e.g.: %s" %
                (repr(init_tensor), repr(op),
                 _path_from(op, init_tensor, sources)))
        # Record `op` as a consumer of each of its graph inputs, then queue the
        # input unless it was already visited or is a stop point.
        # NOTE(review): `(sources or extra_sources)` consults only `sources`
        # when it is non-empty; `extra_sources` is checked only when `sources`
        # is empty/falsy. Also, `graph_inputs` presumably yields ops while
        # `sources` holds tensors -- confirm the membership test is intended.
        for inp in graph_inputs(op):
            op_outputs[inp].add(op)
            if inp not in visited_ops and inp not in (sources
                                                      or extra_sources):
                ops_to_visit.append(inp)

    return extra_sources
    def _get_feeds(self, unfed_input_keys):
        """Returns set of tensors that will be fed.

        Once finalized, the cached `self._feeds` is returned and
        `unfed_input_keys` is ignored. Otherwise the result is every input of
        `self._func_graph` except the component tensors of the structured
        inputs named in `unfed_input_keys`.
        """
        if not self._is_finalized:
            feeds = object_identity.ObjectIdentitySet(self._func_graph.inputs)
            for key in unfed_input_keys:
                feeds = feeds.difference(
                    self._get_component_tensors(self._structured_inputs[key]))
            return feeds
        return self._feeds
示例#21
0
def _add_elements_to_collection(elements, collection_list):
    """Appends `elements` to each named graph collection without duplicates.

    Membership is by object identity. An element already in a collection, or
    repeated within `elements`, ends up in that collection only once.

    Args:
      elements: An element or (nested) structure of elements to add.
      collection_list: A collection name or (nested) structure of names.

    Raises:
      RuntimeError: If called while executing eagerly.
    """
    if context.executing_eagerly():
        raise RuntimeError(
            'Using collections from Layers not supported in Eager '
            'mode. Tried to add %s to %s' % (elements, collection_list))
    elements = nest.flatten(elements)
    collection_list = nest.flatten(collection_list)
    for name in collection_list:
        collection = ops.get_collection_ref(name)
        collection_set = object_identity.ObjectIdentitySet(collection)
        for element in elements:
            if element not in collection_set:
                collection.append(element)
                # Keep the lookup set in sync so a repeated element in
                # `elements` is not appended twice.
                collection_set.add(element)
示例#22
0
def count_params(weights):
  """Count the total number of scalars composing the weights.

  Each weight is counted once, even if it appears several times in `weights`
  (deduplication is by object identity). Unknown (`None`) dimensions count as
  0, so a weight with an unknown dimension contributes nothing instead of
  making `np.prod` raise.

  Arguments:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights
  """
  unique_weights = object_identity.ObjectIdentitySet(weights)
  total = 0
  for weight in unique_weights:
    dims = [0 if dim is None else dim for dim in weight.shape.as_list()]
    total += np.prod(dims)
  return int(total)
示例#23
0
def count_params(weights):
    """Count the total number of scalars composing the weights.

    Weights are deduplicated by object identity. Unknown (`None`) dimensions
    count as 0, so such weights contribute nothing.

    Arguments:
        weights: An iterable containing the weights on which to compute params

    Returns:
        The total number of scalars composing the weights
    """
    distinct = object_identity.ObjectIdentitySet(weights)
    return int(sum(
        np.prod([0 if dim is None else dim for dim in w.shape.as_list()])
        for w in distinct))
示例#24
0
def count_trainable_params(model):
    """Count the number of trainable parameters of a tf.keras model.

    Each trainable weight is counted once, even if it appears more than once
    in `model.trainable_weights` (deduplication is by object identity).

    Args:
        model: tf.keras model

    Returns:
        Total number of trainable parameters
    """
    unique_weights = object_identity.ObjectIdentitySet(model.trainable_weights)
    sizes = (np.prod(weight.shape) for weight in unique_weights)
    return int(sum(sizes))
def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`. For ops in `RESOURCE_READ_OPS`, every resource input is a
    read. For ops without the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr, every
    resource input is a write.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Add all resource inputs to `reads` and return.
    reads.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Add all resource inputs to `writes` and return.
    writes.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  # Set membership does not depend on the attr being sorted, nor on every
  # listed index pointing at a resource input.
  read_only_indices = set(read_only_input_indices)
  for i, t in enumerate(op.inputs):
    if t.dtype != dtypes.resource:
      continue
    if i in read_only_indices:
      reads.add(t)
    else:
      writes.add(t)
  return (reads, writes)
    def get_dependent_input_output_keys(self, input_keys, exclude_output_keys):
        """Determine inputs needed to get outputs excluding exclude_output_keys.

        Args:
          input_keys: A collection of all input keys available to supply to
            the SavedModel.
          exclude_output_keys: A collection of output keys returned by the
            SavedModel that should be excluded.

        Returns:
          A pair of:
            required_input_keys: A subset of the input features to this
              SavedModel that are required to compute the set of output
              features excluding `exclude_output_keys`. It is sorted to be
              deterministic.
            output_keys: The set of output features excluding
              `exclude_output_keys`. It is sorted to be deterministic.

        Raises:
          ValueError: If `input_keys` contains keys that are not structured
            inputs, or `exclude_output_keys` contains keys that are not
            structured outputs. The message lists only the unknown keys.
        """
        # Assert inputs being fed and outputs being excluded are part of the
        # SavedModel.
        unknown_inputs = set(input_keys).difference(
            self._structured_inputs.keys())
        if unknown_inputs:
            raise ValueError(
                'Input tensor names contained tensors not in graph: {}'.format(
                    sorted(unknown_inputs)))

        unknown_outputs = set(exclude_output_keys).difference(
            self._structured_outputs.keys())
        if unknown_outputs:
            raise ValueError(
                'Excluded outputs contained keys not in graph: {}'.format(
                    sorted(unknown_outputs)))

        output_keys = set(self._structured_outputs.keys()).difference(
            exclude_output_keys)

        # Get all the input tensors that are required to evaluate output_keys.
        required_inputs = object_identity.ObjectIdentitySet()
        for key in output_keys:
            required_inputs.update(self._output_to_inputs_map[key])

        # Keep every input feature name that has at least one component
        # tensor in required_inputs.
        required_input_keys = [
            key for key, tensor in six.iteritems(self._structured_inputs)
            if any(x in required_inputs
                   for x in self._get_component_tensors(tensor))
        ]

        return sorted(required_input_keys), sorted(output_keys)
示例#27
0
    def test_model(self):
        """Fits UPerNet once and checks variable tracking."""
        model = build_uper_net(
            classes=2, bone_arch='swin_tiny_224', bone_init='imagenet',
            bone_train=False)
        model.compile(
            optimizer='sgd',
            loss='mse',
            run_eagerly=test_utils.should_run_eagerly())

        images = np.random.random((2, 240, 240, 3)).astype(np.uint8)
        masks = np.random.random((2, 240, 240, 2)).astype(np.float32)
        model.fit(images, masks, epochs=1, batch_size=10)

        # Building the config must not fail.
        model.get_config()

        # Every model variable must appear among the trackable objects.
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`.

  Starts from the resource-typed inputs of `op`. It then re-runs every
  registered resolver (each receives the set and may update it in place)
  until a full pass over the registry reports no update.
  """
  resource_inputs = object_identity.ObjectIdentitySet(
      t for t in op.inputs if t.dtype == dtypes_module.resource)
  while True:
    any_updated = False
    for key in _acd_resource_resolvers_registry.list():
      resolver = _acd_resource_resolvers_registry.lookup(key)
      # Resolvers return true when they updated `resource_inputs`.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      if resolver(op, resource_inputs):
        any_updated = True
    if not any_updated:
      return resource_inputs
示例#29
0
def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Yields each layer once (by object identity), walking `layer_list` in order
  and descending into the `layers` attribute of non-layer containers.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  seen = object_identity.ObjectIdentitySet()
  stack = layer_list[::-1]
  while stack:
    candidate = stack.pop()
    if candidate in seen:
      continue
    seen.add(candidate)
    if not is_layer(candidate):
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      children = getattr(candidate, "layers", None) or []
      stack.extend(children[::-1])
      continue
    yield candidate
示例#30
0
def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Returns a list with each layer once (by object identity), in depth-first
  order, descending into non-layer objects that have a `layers` attribute.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  seen = object_identity.ObjectIdentitySet()
  pending = layer_list[::-1]
  result = []
  while pending:
    candidate = pending.pop()
    if candidate in seen:
      continue
    seen.add(candidate)
    if is_layer(candidate):
      result.append(candidate)
      continue
    if hasattr(candidate, "layers"):
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      pending.extend(candidate.layers[::-1])
  return result