Example 1
  def call(self, inputs, mask=None):
    """Call the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (i.e., it builds a new computational graph from the provided inputs).

    Arguments:
        inputs: A tensor or list of tensors.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there is more than one output.
    """
    inputs = nest.flatten(inputs)
    if mask is None:
      masks = [None for _ in range(len(inputs))]
    else:
      masks = nest.flatten(mask)

    if context.in_graph_mode():
      # Try to retrieve cached outputs if the layer has already been called
      # on these exact inputs.
      cache_key = (layers_util.object_list_uid(inputs)
                   + '_' + layers_util.object_list_uid(masks))
      if cache_key in self._output_tensor_cache:
        # Cache hit.
        return self._output_tensor_cache[cache_key]
    # Actually apply the network graph to the new inputs.
    outputs, _ = self._run_internal_graph(inputs, masks)
    return outputs
Example 2
    def call(self, inputs, mask=None):
        """Call the model on new inputs.

        In this case `call` just reapplies
        all ops in the graph to the new inputs
        (i.e., it builds a new computational graph from the provided inputs).

        Arguments:
            inputs: A tensor or list of tensors.
            mask: A mask or list of masks. A mask can be
                either a tensor or None (no mask).

        Returns:
            A tensor if there is a single output, or
            a list of tensors if there is more than one output.
        """
        inputs = nest.flatten(inputs)
        if mask is None:
            masks = [None for _ in range(len(inputs))]
        else:
            masks = nest.flatten(mask)

        if context.in_graph_mode():
            # Try to retrieve cached outputs if the layer has already been called
            # on these exact inputs.
            cache_key = (layers_util.object_list_uid(inputs) + '_' +
                         layers_util.object_list_uid(masks))
            if cache_key in self._output_tensor_cache:
                # Cache hit.
                return self._output_tensor_cache[cache_key]
        # Actually apply the network graph to the new inputs.
        outputs, _ = self._run_internal_graph(inputs, masks)
        return outputs
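The caching in `call` above keys on the identities of the input and mask objects, so a second call on the exact same tensor objects returns the previously built outputs. A minimal sketch of how such a key could be derived (the real `layers_util.object_list_uid` may differ in detail):

# Hypothetical sketch of an id-based key for a flat list of objects; the real
# layers_util.object_list_uid in TensorFlow may be implemented differently.
def object_list_uid_sketch(object_list):
    return ', '.join(str(abs(id(x))) for x in object_list)

# Two calls with the same tensor objects produce the same key, so the second
# lookup in self._output_tensor_cache is a cache hit.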
Example 3
    def get_updates_for(self, inputs=None):
        # If the wrapper modifies the inputs, use the modified inputs to
        # get the updates from the inner layer.
        inner_inputs = inputs
        if inputs is not None:
            uid = tf_layers_util.object_list_uid(inputs)
            if uid in self._input_map:
                inner_inputs = self._input_map[uid]

        updates = self.layer.get_updates_for(inner_inputs)
        updates += super(Wrapper, self).get_updates_for(inputs)
        return updates
Example 4
  def get_updates_for(self, inputs=None):
    # If the wrapper modifies the inputs, use the modified inputs to
    # get the updates from the inner layer.
    inner_inputs = inputs
    if inputs is not None:
      uid = tf_layers_util.object_list_uid(inputs)
      if uid in self._input_map:
        inner_inputs = self._input_map[uid]

    updates = self.layer.get_updates_for(inner_inputs)
    updates += super(Wrapper, self).get_updates_for(inputs)
    return updates
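The `_input_map` lookup above exists because a wrapper such as `TimeDistributed` may reshape its inputs before calling the inner layer (Examples 5 and 6 below store the reshaped tensor under the uid of the original inputs), so conditional updates must be fetched against the modified inputs. A toy illustration of that bookkeeping, with a hypothetical `uid` helper standing in for `object_list_uid`:

# Toy illustration (hypothetical helper): remember "original -> modified" inputs
# so that per-input bookkeeping can be redirected to the modified tensors.
def uid(obj):
    return str(id(obj))

input_map = {}

def remember_modified(original, modified):
    input_map[uid(original)] = modified

def resolve(original):
    # Fall back to the original inputs if the wrapper never modified them.
    return input_map.get(uid(original), original)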
Example 5
  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    uses_learning_phase = False  # pylint: disable=redefined-outer-name

    input_shape = K.int_shape(inputs)
    if input_shape[0]:
      # batch size matters, use rnn-based implementation
      def step(x, _):
        global uses_learning_phase  # pylint: disable=global-variable-undefined
        output = self.layer.call(x, **kwargs)
        if hasattr(output, '_uses_learning_phase'):
          uses_learning_phase = (output._uses_learning_phase or
                                 uses_learning_phase)
        return output, []

      _, outputs, _ = K.rnn(
          step,
          inputs,
          initial_states=[],
          unroll=False)
      y = outputs
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with a reshape-based implementation for performance.
      input_length = input_shape[1]
      if not input_length:
        input_length = K.shape(inputs)[1]
      # Shape: (num_samples * timesteps, ...). And track the
      # transformation in self._input_map.
      input_uid = tf_layers_util.object_list_uid(inputs)
      inputs = K.reshape(inputs, (-1,) + input_shape[2:])
      self._input_map[input_uid] = inputs
      # (num_samples * timesteps, ...)
      y = self.layer.call(inputs, **kwargs)
      if hasattr(y, '_uses_learning_phase'):
        uses_learning_phase = y._uses_learning_phase
      # Shape: (num_samples, timesteps, ...)
      output_shape = self.compute_output_shape(input_shape).as_list()
      y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:]))

    # Apply activity regularizer if any:
    if (hasattr(self.layer, 'activity_regularizer') and
        self.layer.activity_regularizer is not None):
      regularization_loss = self.layer.activity_regularizer(y)
      self.add_loss(regularization_loss, inputs)

    if uses_learning_phase:
      y._uses_learning_phase = True
    return y
Example 6
    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        uses_learning_phase = False  # pylint: disable=redefined-outer-name

        input_shape = K.int_shape(inputs)
        if input_shape[0]:
            # batch size matters, use rnn-based implementation
            def step(x, _):
                global uses_learning_phase  # pylint: disable=global-variable-undefined
                output = self.layer.call(x, **kwargs)
                if hasattr(output, '_uses_learning_phase'):
                    uses_learning_phase = (output._uses_learning_phase
                                           or uses_learning_phase)
                return output, []

            _, outputs, _ = K.rnn(step,
                                  inputs,
                                  initial_states=[],
                                  unroll=False)
            y = outputs
        else:
            # No batch size specified, therefore the layer will be able
            # to process batches of any size.
            # We can go with a reshape-based implementation for performance.
            input_length = input_shape[1]
            if not input_length:
                input_length = array_ops.shape(inputs)[1]
            # Shape: (num_samples * timesteps, ...). And track the
            # transformation in self._input_map.
            input_uid = tf_layers_util.object_list_uid(inputs)
            inputs = array_ops.reshape(inputs, (-1, ) + input_shape[2:])
            self._input_map[input_uid] = inputs
            # (num_samples * timesteps, ...)
            y = self.layer.call(inputs, **kwargs)
            if hasattr(y, '_uses_learning_phase'):
                uses_learning_phase = y._uses_learning_phase
            # Shape: (num_samples, timesteps, ...)
            output_shape = self.compute_output_shape(input_shape).as_list()
            y = array_ops.reshape(y,
                                  (-1, input_length) + tuple(output_shape[2:]))

        # Apply activity regularizer if any:
        if (hasattr(self.layer, 'activity_regularizer')
                and self.layer.activity_regularizer is not None):
            regularization_loss = self.layer.activity_regularizer(y)
            self.add_loss(regularization_loss, inputs)

        if uses_learning_phase:
            y._uses_learning_phase = True
        return y
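For context, the reshape-based branch above is what typically runs when the batch size is not fixed, e.g. when applying the same `Dense` layer independently to every timestep. A minimal usage sketch, assuming the public Keras API:

from tensorflow.keras import layers, models

# (batch, timesteps, features): the wrapped Dense layer is applied per timestep.
inputs = layers.Input(shape=(10, 16))
outputs = layers.TimeDistributed(layers.Dense(8))(inputs)
model = models.Model(inputs, outputs)
# Expected output shape: (None, 10, 8)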
Example 7
    def _run_internal_graph(self, inputs, masks=None):
        """Computes output tensors for new inputs.

        # Note:
            - Expects `inputs` to be a list (potentially with 1 element).
            - Can be run on non-Keras tensors.

        Arguments:
            inputs: List of tensors
            masks: List of masks (tensors or None).

        Returns:
            Two lists: output_tensors, output_masks.
        """
        # Note: masking support is relevant mainly for Keras.
        # It cannot be factored out without having to fully reimplement the network
        # calling logic on the Keras side. We choose to incorporate it in
        # GraphNetwork because 1) it may be useful to fully support it in tf.layers
        # in the future and 2) Keras is a major user of GraphNetwork.  If you don't
        # use masking, it does not interfere with regular behavior at all and you
        # can ignore it.
        if masks is None:
            masks = [None for _ in range(len(inputs))]

        # Dictionary mapping reference tensors to tuples
        # (computed tensor, computed mask)
        # we assume a 1:1 mapping from tensor to mask
        # TODO(fchollet): raise exception when a `.compute_mask()` call
        # does not return a list the same size as `call`
        tensor_map = {}
        for x, y, mask in zip(self.inputs, inputs, masks):
            tensor_map[str(id(x))] = (y, mask)

        depth_keys = list(self._nodes_by_depth.keys())
        depth_keys.sort(reverse=True)
        for depth in depth_keys:
            nodes = self._nodes_by_depth[depth]
            for node in nodes:
                # This is always a single layer, never a list.
                layer = node.outbound_layer

                reference_input_tensors = node.input_tensors
                reference_output_tensors = node.output_tensors

                # If all previous input tensors are available in tensor_map,
                # then call node.inbound_layer on them.
                computed_data = []  # List of tuples (input, mask).
                for x in reference_input_tensors:
                    if str(id(x)) in tensor_map:
                        computed_data.append(tensor_map[str(id(x))])

                if len(computed_data) == len(reference_input_tensors):
                    # Call layer (reapplying ops to new inputs).
                    with ops.name_scope(layer.name):
                        if node.arguments:
                            kwargs = node.arguments
                        else:
                            kwargs = {}
                        if len(computed_data) == 1:
                            computed_tensor, computed_mask = computed_data[0]
                            # Ensure mask propagation if applicable.
                            if 'mask' in estimator_util.fn_args(layer.call):
                                if 'mask' not in kwargs:
                                    kwargs['mask'] = computed_mask

                            output_tensors = nest.flatten(
                                layer.call(computed_tensor, **kwargs))
                            if hasattr(layer, 'compute_mask'):
                                output_masks = nest.flatten(
                                    layer.compute_mask(computed_tensor,
                                                       computed_mask))
                            else:
                                output_masks = [
                                    None for _ in range(len(output_tensors))
                                ]
                            computed_tensors = [computed_tensor]
                            computed_masks = [computed_mask]
                        else:
                            computed_tensors = [x[0] for x in computed_data]
                            computed_masks = [x[1] for x in computed_data]
                            if 'mask' in estimator_util.fn_args(layer.call):
                                if 'mask' not in kwargs:
                                    kwargs['mask'] = computed_masks
                            output_tensors = nest.flatten(
                                layer.call(computed_tensors, **kwargs))
                            if hasattr(layer, 'compute_mask'):
                                output_masks = nest.flatten(
                                    layer.compute_mask(computed_tensors,
                                                       computed_masks))
                            else:
                                output_masks = [
                                    None for _ in range(len(output_tensors))
                                ]

                        # Apply activity regularizer if any:
                        if layer.activity_regularizer is not None:
                            regularization_losses = [
                                layer.activity_regularizer(x)
                                for x in computed_tensors
                            ]
                            layer.add_loss(regularization_losses,
                                           computed_tensors)

                    if context.in_graph_mode():
                        # Update model updates and losses:
                        # Keep track of updates that depend on the inputs
                        # (e.g. BN updates).
                        self.add_update(
                            layer.get_updates_for(computed_tensors), inputs)
                        # Keep track of unconditional updates (e.g. a counter).
                        self.add_update(layer.get_updates_for(None), None)
                        # Keep track of losses that depend on the inputs
                        # (e.g. activity regularizers).
                        self.add_loss(layer.get_losses_for(computed_tensors),
                                      inputs)
                        # Keep track of unconditional losses
                        # (e.g. weight regularizers).
                        self.add_loss(layer.get_losses_for(None), None)

                    # Update tensor_map.
                    for x, y, mask in zip(reference_output_tensors,
                                          output_tensors, output_masks):
                        tensor_map[str(id(x))] = (y, mask)

        output_tensors = []
        output_masks = []
        output_shapes = []
        for x in self.outputs:
            assert str(
                id(x)) in tensor_map, 'Could not compute output ' + str(x)
            tensor, mask = tensor_map[str(id(x))]
            output_shapes.append(layers_util.static_shape(x))
            output_tensors.append(tensor)
            output_masks.append(mask)

        if len(output_tensors) == 1:
            output_tensors = output_tensors[0]
            if output_shapes is not None:
                output_shapes = output_shapes[0]
            if output_masks is not None:
                output_masks = output_masks[0]

        if context.in_graph_mode():
            # Update cache;
            # keys are based on ids on input tensors and inputs masks.
            cache_key = (layers_util.object_list_uid(inputs) + '_' +
                         layers_util.object_list_uid(masks))
            self._output_tensor_cache[cache_key] = output_tensors
            if output_masks is not None:
                self._output_mask_cache[cache_key] = output_masks
            if output_shapes is not None:
                input_shapes = [layers_util.static_shape(x) for x in inputs]
                cache_key = layers_util.object_list_uid(input_shapes)
                self._output_shape_cache[cache_key] = output_shapes

        return output_tensors, output_masks
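The heart of `_run_internal_graph` is a topological sweep: nodes are visited from the deepest level downward, and a layer is re-called only once every one of its reference input tensors has been resolved in `tensor_map`. A stripped-down sketch of that bookkeeping (hypothetical standalone function; masks and caching omitted):

# Hypothetical sketch of the graph replay above, with masks and caching omitted.
# `nodes_by_depth` maps depth -> list of nodes; each node exposes input_tensors,
# output_tensors and outbound_layer, mirroring the structures used above.
def replay_graph(reference_inputs, new_inputs, nodes_by_depth):
    tensor_map = {str(id(r)): y for r, y in zip(reference_inputs, new_inputs)}
    for depth in sorted(nodes_by_depth, reverse=True):
        for node in nodes_by_depth[depth]:
            refs = node.input_tensors
            if not all(str(id(x)) in tensor_map for x in refs):
                continue  # some inputs not resolved yet (mirrors the check above)
            computed = [tensor_map[str(id(x))] for x in refs]
            outputs = node.outbound_layer.call(
                computed[0] if len(computed) == 1 else computed)
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            for ref, out in zip(node.output_tensors, outputs):
                tensor_map[str(id(ref))] = out
    return tensor_map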
Example 8
    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shapes = []
            for shape in input_shape:
                if shape is not None:
                    input_shapes.append(
                        tuple(tensor_shape.TensorShape(shape).as_list()))
                else:
                    input_shapes.append(None)
        else:
            if input_shape is not None:
                input_shapes = [
                    tuple(tensor_shape.TensorShape(input_shape).as_list())
                ]
            else:
                input_shapes = [None]

        if len(input_shapes) != len(self._input_layers):
            raise ValueError('Invalid input_shape argument ' +
                             str(input_shape) + ': model has ' +
                             str(len(self._input_layers)) + ' tensor inputs.')

        cache_key = layers_util.object_list_uid(input_shapes)
        if cache_key not in self._output_shape_cache:
            # Cache miss. We have to run the network graph manually (recursive calls
            # to `compute_output_shape`).
            layers_to_output_shapes = {}
            for i in range(len(input_shapes)):
                layer = self._input_layers[i]
                input_shape = input_shapes[i]
                # It's an input layer: then `compute_output_shape` is identity,
                # and there is only one node and one tensor output.
                shape_key = layer.name + '_0_0'
                layers_to_output_shapes[shape_key] = input_shape

            depth_keys = list(self._nodes_by_depth.keys())
            depth_keys.sort(reverse=True)
            # Iterate over nodes, by depth level.
            if len(depth_keys) > 1:
                for depth in depth_keys:
                    nodes = self._nodes_by_depth[depth]
                    for node in nodes:
                        # This is always a single layer, never a list.
                        layer = node.outbound_layer
                        if layer in self._input_layers:
                            # We've already covered the input layers
                            # a few lines above.
                            continue
                        # Potentially redundant list,
                        # same size as node.input_tensors.
                        input_shapes = []
                        for j in range(len(node.inbound_layers)):
                            inbound_layer = node.inbound_layers[j]
                            node_index = node.node_indices[j]
                            tensor_index = node.tensor_indices[j]
                            shape_key = inbound_layer.name + '_%s_%s' % (
                                node_index, tensor_index)
                            input_shape = layers_to_output_shapes[shape_key]
                            input_shapes.append(input_shape)

                        if len(input_shapes) == 1:
                            output_shape = layer.compute_output_shape(
                                input_shapes[0])
                        else:
                            output_shape = layer.compute_output_shape(
                                input_shapes)
                        if isinstance(output_shape, list):
                            output_shapes = [
                                tuple(
                                    tensor_shape.TensorShape(shape).as_list())
                                for shape in output_shape
                            ]
                        else:
                            output_shapes = [
                                tuple(
                                    tensor_shape.TensorShape(
                                        output_shape).as_list())
                            ]

                        node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
                        for j in range(len(output_shapes)):
                            shape_key = layer.name + '_%s_%s' % (node_index, j)
                            layers_to_output_shapes[shape_key] = output_shapes[
                                j]

                # Read final output shapes from layers_to_output_shapes.
                output_shapes = []
                for i in range(len(self._output_layers)):
                    layer, node_index, tensor_index = self._output_coordinates[
                        i]
                    shape_key = layer.name + '_%s_%s' % (node_index,
                                                         tensor_index)
                    output_shapes.append(layers_to_output_shapes[shape_key])

                # Store in cache.
                self._output_shape_cache[cache_key] = output_shapes
        else:
            # Cache hit.
            output_shapes = self._output_shape_cache[cache_key]

        if isinstance(output_shapes, list):
            if len(output_shapes) == 1:
                return tensor_shape.TensorShape(output_shapes[0])
            else:
                return [
                    tensor_shape.TensorShape(shape) for shape in output_shapes
                ]
        else:
            return tensor_shape.TensorShape(output_shapes)
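In practice this shape-inference path is exercised when a graph network is asked for its output shape without being run on real tensors. A minimal usage sketch, assuming the public Keras API:

from tensorflow.keras import layers, models

inp = layers.Input(shape=(32,))
out = layers.Dense(4)(layers.Dense(8)(inp))
model = models.Model(inp, out)
# Shape inference without running the model on data.
print(model.compute_output_shape((None, 32)))  # -> (None, 4)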
Example 9
  def _run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    # Note:
        - Expects `inputs` to be a list (potentially with 1 element).
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: List of tensors
        masks: List of masks (tensors or None).

    Returns:
        Two lists: output_tensors, output_masks.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # GraphNetwork because 1) it may be useful to fully support it in tf.layers
    # in the future and 2) Keras is a major user of GraphNetwork.  If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.
    if masks is None:
      masks = [None for _ in range(len(inputs))]

    # Dictionary mapping reference tensors to tuples
    # (computed tensor, computed mask)
    # we assume a 1:1 mapping from tensor to mask
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # If all previous input tensors are available in tensor_map,
        # then call node.inbound_layer on them.
        computed_data = []  # List of tuples (input, mask).
        for x in reference_input_tensors:
          if str(id(x)) in tensor_map:
            computed_data.append(tensor_map[str(id(x))])

        if len(computed_data) == len(reference_input_tensors):
          # Call layer (reapplying ops to new inputs).
          with ops.name_scope(layer.name):
            if node.arguments:
              kwargs = node.arguments
            else:
              kwargs = {}
            if len(computed_data) == 1:
              computed_tensor, computed_mask = computed_data[0]
              # Ensure mask propagation if applicable.
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_mask

              output_tensors = nest.flatten(
                  layer.call(computed_tensor, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensor, computed_mask))
              else:
                output_masks = [None for _ in range(len(output_tensors))]
              computed_tensors = [computed_tensor]
              computed_masks = [computed_mask]
            else:
              computed_tensors = [x[0] for x in computed_data]
              computed_masks = [x[1] for x in computed_data]
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_masks
              output_tensors = nest.flatten(
                  layer.call(computed_tensors, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensors, computed_masks))
              else:
                output_masks = [None for _ in range(len(output_tensors))]

            # Apply activity regularizer if any:
            if layer.activity_regularizer is not None:
              regularization_losses = [
                  layer.activity_regularizer(x) for x in computed_tensors
              ]
              layer.add_loss(regularization_losses, computed_tensors)

          if context.in_graph_mode():
            # Update model updates and losses:
            # Keep track of updates that depend on the inputs
            # (e.g. BN updates).
            self.add_update(layer.get_updates_for(computed_tensors), inputs)
            # Keep track of unconditional updates (e.g. a counter).
            self.add_update(layer.get_updates_for(None), None)
            # Keep track of losses that depend on the inputs
            # (e.g. activity regularizers).
            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
            # Keep track of unconditional losses
            # (e.g. weight regularizers).
            self.add_loss(layer.get_losses_for(None), None)

          # Update tensor_map.
          for x, y, mask in zip(reference_output_tensors, output_tensors,
                                output_masks):
            tensor_map[str(id(x))] = (y, mask)

    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      output_shapes.append(layers_util.static_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
      if output_shapes is not None:
        output_shapes = output_shapes[0]
      if output_masks is not None:
        output_masks = output_masks[0]

    if context.in_graph_mode():
      # Update cache;
      # keys are based on ids on input tensors and inputs masks.
      cache_key = (layers_util.object_list_uid(inputs)
                   + '_' + layers_util.object_list_uid(masks))
      self._output_tensor_cache[cache_key] = output_tensors
      if output_masks is not None:
        self._output_mask_cache[cache_key] = output_masks
      if output_shapes is not None:
        input_shapes = [layers_util.static_shape(x) for x in inputs]
        cache_key = layers_util.object_list_uid(input_shapes)
        self._output_shape_cache[cache_key] = output_shapes

    return output_tensors, output_masks
Example 10
  def _compute_output_shape(self, input_shape):
    if isinstance(input_shape, list):
      input_shapes = []
      for shape in input_shape:
        if shape is not None:
          input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
        else:
          input_shapes.append(None)
    else:
      if input_shape is not None:
        input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
      else:
        input_shapes = [None]

    if len(input_shapes) != len(self._input_layers):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    cache_key = layers_util.object_list_uid(input_shapes)
    if cache_key not in self._output_shape_cache:
      # Cache miss. We have to run the network graph manually (recursive calls
      # to `_compute_output_shape`).
      layers_to_output_shapes = {}
      for i in range(len(input_shapes)):
        layer = self._input_layers[i]
        input_shape = input_shapes[i]
        # It's an input layer: then `_compute_output_shape` is identity,
        # and there is only one node and one tensor output.
        shape_key = layer.name + '_0_0'
        layers_to_output_shapes[shape_key] = input_shape

      depth_keys = list(self._nodes_by_depth.keys())
      depth_keys.sort(reverse=True)
      # Iterate over nodes, by depth level.
      if len(depth_keys) > 1:
        for depth in depth_keys:
          nodes = self._nodes_by_depth[depth]
          for node in nodes:
            # This is always a single layer, never a list.
            layer = node.outbound_layer
            if layer in self._input_layers:
              # We've already covered the input layers
              # a few lines above.
              continue
            # Potentially redundant list,
            # same size as node.input_tensors.
            input_shapes = []
            for j in range(len(node.inbound_layers)):
              inbound_layer = node.inbound_layers[j]
              node_index = node.node_indices[j]
              tensor_index = node.tensor_indices[j]
              shape_key = inbound_layer.name + '_%s_%s' % (node_index,
                                                           tensor_index)
              input_shape = layers_to_output_shapes[shape_key]
              input_shapes.append(input_shape)

            if len(input_shapes) == 1:
              output_shape = layer._compute_output_shape(input_shapes[0])  # pylint: disable=protected-access
            else:
              output_shape = layer._compute_output_shape(input_shapes)  # pylint: disable=protected-access
            if isinstance(output_shape, list):
              output_shapes = [
                  tuple(tensor_shape.TensorShape(shape).as_list())
                  for shape in output_shape
              ]
            else:
              output_shapes = [
                  tuple(tensor_shape.TensorShape(output_shape).as_list())
              ]

            node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
            for j in range(len(output_shapes)):
              shape_key = layer.name + '_%s_%s' % (node_index, j)
              layers_to_output_shapes[shape_key] = output_shapes[j]

        # Read final output shapes from layers_to_output_shapes.
        output_shapes = []
        for i in range(len(self._output_layers)):
          layer, node_index, tensor_index = self._output_coordinates[i]
          shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
          output_shapes.append(layers_to_output_shapes[shape_key])

        # Store in cache.
        self._output_shape_cache[cache_key] = output_shapes
    else:
      # Cache hit.
      output_shapes = self._output_shape_cache[cache_key]

    if isinstance(output_shapes, list):
      if len(output_shapes) == 1:
        return tensor_shape.TensorShape(output_shapes[0])
      else:
        return [tensor_shape.TensorShape(shape) for shape in output_shapes]
    else:
      return tensor_shape.TensorShape(output_shapes)