Beispiel #1
0
        else:
          (name, label) = axis_spec
          reshaped_axes.append((name, (label,)))

        shape.append(1)

    reshaped_tensor = array_ops.reshape(labeled_tensor.tensor, shape,
                                        name=scope)

    return LabeledTensor(reshaped_tensor, reshaped_axes)

# Graph-collection key under which the active axis order is stored; it is read
# back by get_axis_order via ops.get_collection.
# This should only be added to a graph collection once: get_axis_order unpacks
# the collection as a single element, so more than one entry raises ValueError.
_AXIS_ORDER_KEY = ('__axis_order',)


@tc.returns(tc.Optional(tc.List(string_types)))
def get_axis_order():
  """Look up the axis order established by an enclosing axis_order_scope.

  Returns:
    The list of axis-name strings registered for the current graph, or None
    when no axis order has been set.

  Raises:
    ValueError: If the axis-order collection holds more than one entry.
  """
  # The order is kept in a graph collection rather than a Python global; per
  # the original authors this is what makes axis_order_scope thread-safe.
  stored = ops.get_collection(_AXIS_ORDER_KEY)
  if not stored:
    return None
  # Pattern-unpacking insists on exactly one stored entry.
  [current_order] = stored
  return current_order
Beispiel #2
0
try:
  # The ABC aliases in `collections` (e.g. collections.Mapping) were removed
  # in Python 3.10; the ABCs live in collections.abc since Python 3.3.
  from collections.abc import Mapping as _Mapping
except ImportError:
  # Python 2 fallback.
  from collections import Mapping as _Mapping


class Axes(_Mapping):
  """Axis names and indices for a tensor.

  It is an ordered mapping, with keys given by axis name and values given
  by Axis objects. Duplicate axis names are not allowed.

  Note: equality is inherited from Mapping, which compares the
  name -> Axis pairs without regard to their order. __hash__ is likewise
  order-insensitive so that equal Axes always hash equally.
  """

  @tc.accepts(object, tc.List(AxisLike))
  def __init__(self, axes):
    """Construct an Axes.

    Args:
      axes: A list of Axis objects or (axis_name, axis_value) tuples. Each
        element is converted with as_axis.

    Raises:
      ValueError: If two entries share the same axis name. Other malformed
        entries (possibly including empty names) would have to be rejected
        by as_axis; that is not checked in this constructor.
    """
    # OrderedDict keeps the caller's axis order for iteration and repr.
    self._axes = collections.OrderedDict()

    for axis_data in axes:
      axis = as_axis(axis_data)

      name = axis.name
      if name in self._axes:
        raise ValueError('Duplicate axis name: %s' % name)

      self._axes[name] = axis

  def __iter__(self):
    """Iterate over axis names in insertion order."""
    return iter(self._axes)

  @tc.returns(string_types)
  def __repr__(self):
    # Axes([('x', Dimension(2)),
    #       ('y', ['a', 'b', 'c']),
    #       ('z', Dimension(4))])
    cls_name = type(self).__name__
    values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
    # Continuation lines are indented to line up under the opening "([".
    values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
    return '%s([%s])' % (cls_name, values_repr)

  @tc.returns(Axis)
  @tc.accepts(object, string_types)
  def __getitem__(self, name):
    """Return the Axis registered under `name`; raises KeyError if absent."""
    return self._axes[name]

  @tc.returns(bool)
  def __contains__(self, name):
    return name in self._axes

  @tc.returns(int)
  def __len__(self):
    return len(self._axes)

  def __hash__(self):
    # Mapping.__eq__ compares dict(self.items()), which ignores order, so the
    # hash must ignore order too or equal objects could hash differently.
    # Axis names are unique, so no items collapse inside the frozenset.
    return hash(frozenset(self.items()))

  @tc.accepts(object, string_types)
  def remove(self, axis_name):
    """Creates a new Axes object without the given axis.

    Args:
      axis_name: Name of the axis to drop.

    Returns:
      A new Axes holding the remaining axes in their original order.

    Raises:
      KeyError: If `axis_name` is not one of the axes.
    """
    if axis_name not in self:
      raise KeyError(axis_name)
    remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
    return Axes(remaining_axes)
Beispiel #3
0
        axes_0 = labeled_tensors[0].axes
        for t in labeled_tensors:
            if t.axes != axes_0:
                raise ValueError('Non-identical axes. Expected %s but got %s' %
                                 (axes_0, t.axes))

        pack_op = array_ops.stack([t.tensor for t in labeled_tensors],
                                  axis=axis_position,
                                  name=scope)
        axes = list(axes_0.values())
        axes.insert(axis_position, new_axis)
        return core.LabeledTensor(pack_op, axes)


@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(core.LabeledTensorLike, tc.Optional(string_types),
            tc.Optional(string_types))
def unpack(labeled_tensor, axis_name=None, name=None):
    """Unpack the tensor.

  See tf.unpack.

  Args:
    labeled_tensor: The input tensor.
    axis_name: Optional name of axis to unpack. By default, the first axis is
      used.
    name: Optional op name.

  Returns:
    The list of unpacked LabeledTensors.