Example #1
0
    def _run_internal_graph(self, inputs, masks=None):
        """Computes output tensors for new inputs.

        Note:
            - Expects `inputs` to be a list (potentially with 1 element).
            - Can be run on non-Keras tensors.

        Arguments:
            inputs: List of tensors.
            masks: List of masks (tensors or None), one per input. If None,
                a mask of None is used for every input.

        Returns:
            A tuple `(output_tensors, output_masks)`. Each element is a list
            with one entry per network output, or that single entry itself
            when the network has exactly one output.

        In graph mode this also registers the layers' updates and losses on
        this network, and stores the computed outputs (plus masks and static
        output shapes, when not None) in `self._output_tensor_cache`,
        `self._output_mask_cache` and `self._output_shape_cache`.
        """
        # Note: masking support is relevant mainly for Keras.
        # It cannot be factored out without having to fully reimplement the
        # network calling logic on the Keras side. We choose to incorporate it
        # in GraphNetwork because 1) it may be useful to fully support in
        # tf.layers in the future and 2) Keras is a major user of
        # GraphNetwork. If you don't use masking, it does not interfere with
        # regular behavior at all and you can ignore it.
        if masks is None:
            masks = [None] * len(inputs)

        # Dictionary mapping reference tensors (keyed by `str(id(tensor))`)
        # to tuples (computed tensor, computed mask).
        # We assume a 1:1 mapping from tensor to mask.
        # TODO(fchollet): raise exception when a `.compute_mask()` call
        # does not return a list the same size as `call`.
        tensor_map = {}
        for x, y, mask in zip(self.inputs, inputs, masks):
            tensor_map[str(id(x))] = (y, mask)

        # Visit nodes in decreasing order of depth; a node is only run once
        # all of its reference inputs have been computed.
        for depth in sorted(self._nodes_by_depth, reverse=True):
            for node in self._nodes_by_depth[depth]:
                # This is always a single layer, never a list.
                layer = node.outbound_layer

                reference_input_tensors = node.input_tensors
                reference_output_tensors = node.output_tensors

                # Collect (input, mask) tuples for the inputs computed so far.
                computed_data = [
                    tensor_map[str(id(x))]
                    for x in reference_input_tensors
                    if str(id(x)) in tensor_map
                ]
                # Only call node.outbound_layer once all of its inputs are
                # available in tensor_map.
                if len(computed_data) != len(reference_input_tensors):
                    continue

                # Call layer (reapplying ops to new inputs).
                with ops.name_scope(layer.name):
                    # Work on a copy of the stored call arguments: the mask
                    # injected below belongs to this call only. Writing it
                    # into `node.arguments` itself would make every later
                    # call reuse this call's (stale) mask.
                    kwargs = dict(node.arguments) if node.arguments else {}
                    if len(computed_data) == 1:
                        # Single-input layers are called on the bare tensor
                        # and mask rather than on 1-element lists.
                        call_inputs, call_masks = computed_data[0]
                        computed_tensors = [call_inputs]
                    else:
                        computed_tensors = [x[0] for x in computed_data]
                        call_inputs = computed_tensors
                        call_masks = [x[1] for x in computed_data]

                    # Ensure mask propagation if applicable.
                    if ('mask' in estimator_util.fn_args(layer.call) and
                            'mask' not in kwargs):
                        kwargs['mask'] = call_masks

                    output_tensors = nest.flatten(
                        layer.call(call_inputs, **kwargs))
                    if hasattr(layer, 'compute_mask'):
                        output_masks = nest.flatten(
                            layer.compute_mask(call_inputs, call_masks))
                    else:
                        output_masks = [None] * len(output_tensors)

                    # Apply activity regularizer if any. Activity
                    # regularization penalizes the layer's outputs (its
                    # activations); the resulting losses are conditional on
                    # the layer's inputs.
                    if layer.activity_regularizer is not None:
                        regularization_losses = [
                            layer.activity_regularizer(x)
                            for x in output_tensors
                        ]
                        layer.add_loss(regularization_losses,
                                       computed_tensors)

                if context.in_graph_mode():
                    # Update model updates and losses:
                    # Keep track of updates that depend on the inputs
                    # (e.g. BN updates).
                    self.add_update(
                        layer.get_updates_for(computed_tensors), inputs)
                    # Keep track of unconditional updates (e.g. a counter).
                    self.add_update(layer.get_updates_for(None), None)
                    # Keep track of losses that depend on the inputs
                    # (e.g. activity regularizers).
                    self.add_loss(layer.get_losses_for(computed_tensors),
                                  inputs)
                    # Keep track of unconditional losses
                    # (e.g. weight regularizers).
                    self.add_loss(layer.get_losses_for(None), None)

                # Update tensor_map.
                for x, y, mask in zip(reference_output_tensors,
                                      output_tensors, output_masks):
                    tensor_map[str(id(x))] = (y, mask)

        output_tensors = []
        output_masks = []
        output_shapes = []
        for x in self.outputs:
            assert str(id(x)) in tensor_map, (
                'Could not compute output ' + str(x))
            tensor, mask = tensor_map[str(id(x))]
            output_shapes.append(layers_util.static_shape(x))
            output_tensors.append(tensor)
            output_masks.append(mask)

        # A single-output network returns bare values instead of 1-element
        # lists.
        if len(output_tensors) == 1:
            output_tensors = output_tensors[0]
            output_masks = output_masks[0]
            output_shapes = output_shapes[0]

        if context.in_graph_mode():
            # Update cache;
            # keys are based on ids of input tensors and input masks.
            cache_key = (layers_util.object_list_uid(inputs) + '_' +
                         layers_util.object_list_uid(masks))
            self._output_tensor_cache[cache_key] = output_tensors
            # After unwrapping a single output, output_masks (and possibly
            # output_shapes) can be None; those are not cached.
            if output_masks is not None:
                self._output_mask_cache[cache_key] = output_masks
            if output_shapes is not None:
                input_shapes = [layers_util.static_shape(x) for x in inputs]
                cache_key = layers_util.object_list_uid(input_shapes)
                self._output_shape_cache[cache_key] = output_shapes

        return output_tensors, output_masks
Example #2
0
  def _run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    Note:
      - Expects `inputs` to be a list (potentially with 1 element).
      - Can be run on non-Keras tensors.

    Arguments:
      inputs: List of tensors.
      masks: List of masks (tensors or None), one per input. If None, a mask
        of None is used for every input.

    Returns:
      A tuple `(output_tensors, output_masks)`. Each element is a list with
      one entry per network output, or that single entry itself when the
      network has exactly one output.

    In graph mode this also registers the layers' updates and losses on this
    network, and stores the computed outputs (plus masks and static output
    shapes, when not None) in `self._output_tensor_cache`,
    `self._output_mask_cache` and `self._output_shape_cache`.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # GraphNetwork because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of GraphNetwork. If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.
    if masks is None:
      masks = [None] * len(inputs)

    # Dictionary mapping reference tensors (keyed by `str(id(tensor))`) to
    # tuples (computed tensor, computed mask).
    # We assume a 1:1 mapping from tensor to mask.
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`.
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    # Visit nodes in decreasing order of depth; a node is only run once all of
    # its reference inputs have been computed.
    for depth in sorted(self._nodes_by_depth, reverse=True):
      for node in self._nodes_by_depth[depth]:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # Collect (input, mask) tuples for the inputs computed so far.
        computed_data = [
            tensor_map[str(id(x))]
            for x in reference_input_tensors
            if str(id(x)) in tensor_map
        ]
        # Only call node.outbound_layer once all of its inputs are available
        # in tensor_map.
        if len(computed_data) != len(reference_input_tensors):
          continue

        # Call layer (reapplying ops to new inputs).
        with ops.name_scope(layer.name):
          # Work on a copy of the stored call arguments: the mask injected
          # below belongs to this call only. Writing it into `node.arguments`
          # itself would make every later call reuse this call's (stale) mask.
          kwargs = dict(node.arguments) if node.arguments else {}
          if len(computed_data) == 1:
            # Single-input layers are called on the bare tensor and mask
            # rather than on 1-element lists.
            call_inputs, call_masks = computed_data[0]
            computed_tensors = [call_inputs]
          else:
            computed_tensors = [x[0] for x in computed_data]
            call_inputs = computed_tensors
            call_masks = [x[1] for x in computed_data]

          # Ensure mask propagation if applicable.
          if ('mask' in estimator_util.fn_args(layer.call) and
              'mask' not in kwargs):
            kwargs['mask'] = call_masks

          output_tensors = nest.flatten(layer.call(call_inputs, **kwargs))
          if hasattr(layer, 'compute_mask'):
            output_masks = nest.flatten(
                layer.compute_mask(call_inputs, call_masks))
          else:
            output_masks = [None] * len(output_tensors)

          # Apply activity regularizer if any. Activity regularization
          # penalizes the layer's outputs (its activations); the resulting
          # losses are conditional on the layer's inputs.
          if layer.activity_regularizer is not None:
            regularization_losses = [
                layer.activity_regularizer(x) for x in output_tensors
            ]
            layer.add_loss(regularization_losses, computed_tensors)

        if context.in_graph_mode():
          # Update model updates and losses:
          # Keep track of updates that depend on the inputs
          # (e.g. BN updates).
          self.add_update(layer.get_updates_for(computed_tensors), inputs)
          # Keep track of unconditional updates (e.g. a counter).
          self.add_update(layer.get_updates_for(None), None)
          # Keep track of losses that depend on the inputs
          # (e.g. activity regularizers).
          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
          # Keep track of unconditional losses
          # (e.g. weight regularizers).
          self.add_loss(layer.get_losses_for(None), None)

        # Update tensor_map.
        for x, y, mask in zip(reference_output_tensors, output_tensors,
                              output_masks):
          tensor_map[str(id(x))] = (y, mask)

    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      output_shapes.append(layers_util.static_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    # A single-output network returns bare values instead of 1-element lists.
    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
      output_masks = output_masks[0]
      output_shapes = output_shapes[0]

    if context.in_graph_mode():
      # Update cache;
      # keys are based on ids of input tensors and input masks.
      cache_key = (layers_util.object_list_uid(inputs)
                   + '_' + layers_util.object_list_uid(masks))
      self._output_tensor_cache[cache_key] = output_tensors
      # After unwrapping a single output, output_masks (and possibly
      # output_shapes) can be None; those are not cached.
      if output_masks is not None:
        self._output_mask_cache[cache_key] = output_masks
      if output_shapes is not None:
        input_shapes = [layers_util.static_shape(x) for x in inputs]
        cache_key = layers_util.object_list_uid(input_shapes)
        self._output_shape_cache[cache_key] = output_shapes

    return output_tensors, output_masks