else: (name, label) = axis_spec reshaped_axes.append((name, (label,))) shape.append(1) reshaped_tensor = array_ops.reshape(labeled_tensor.tensor, shape, name=scope) return LabeledTensor(reshaped_tensor, reshaped_axes) # This should only be added to a graph collection once. _AXIS_ORDER_KEY = ('__axis_order',) @tc.returns(tc.Optional(tc.List(string_types))) def get_axis_order(): """Get the axis_order set by any containing axis_order_scope. Returns: List of strings giving an order to use for axis names, or None, if no axis order is set. """ # By storing axis_order in the graph, we can ensure that axis_order_scope is # thread-safe. axis_order_list = ops.get_collection(_AXIS_ORDER_KEY) if axis_order_list: axis_order, = axis_order_list else: axis_order = None return axis_order
try:
  # Python 3.3+: the ABCs live in collections.abc. The aliases in the
  # top-level `collections` namespace were removed in Python 3.10, so
  # `collections.Mapping` raises AttributeError there.
  from collections.abc import Mapping as _AbcMapping
except ImportError:  # Python 2 fallback.
  from collections import Mapping as _AbcMapping


class Axes(_AbcMapping):
  """Axis names and indices for a tensor.

  It is an ordered mapping, with keys given by axis name and values given by
  Axis objects. Duplicate axis names are not allowed.
  """

  @tc.accepts(object, tc.List(AxisLike))
  def __init__(self, axes):
    """Construct an Axes.

    Args:
      axes: A list of Axis objects or (axis_name, axis_value) tuples.

    Raises:
      ValueError: If two entries in `axes` share the same name. Validation of
        individual entries (e.g. empty names) is delegated to `as_axis`, whose
        errors propagate unchanged.
    """
    # OrderedDict preserves the caller's axis order for iteration.
    self._axes = collections.OrderedDict()
    for axis_data in axes:
      axis = as_axis(axis_data)
      name = axis.name
      if name in self._axes:
        raise ValueError('Duplicate axis name: %s' % name)
      self._axes[name] = axis

  def __iter__(self):
    return iter(self._axes)

  @tc.returns(string_types)
  def __repr__(self):
    # Example output:
    #   Axes([('x', Dimension(2)),
    #         ('y', ['a', 'b', 'c']),
    #         ('z', Dimension(4))])
    cls_name = type(self).__name__
    values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
    # Continuation lines are indented to align under the first entry.
    values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
    return '%s([%s])' % (cls_name, values_repr)

  @tc.returns(Axis)
  @tc.accepts(object, string_types)
  def __getitem__(self, name):
    return self._axes[name]

  @tc.returns(bool)
  def __contains__(self, name):
    return name in self._axes

  @tc.returns(int)
  def __len__(self):
    return len(self._axes)

  def __hash__(self):
    # Equality is inherited from Mapping.__eq__, which compares
    # dict(self.items()) and therefore ignores axis order. Objects that
    # compare equal must hash equal, so the hash must ignore order too;
    # hashing an ordered tuple here would violate that contract.
    # NOTE(review): for an *ordered* mapping, order-sensitive equality may be
    # what callers actually intend (e.g. "identical axes" checks) -- confirm
    # before relying on == to distinguish axis orders.
    return hash(frozenset(self.items()))

  @tc.accepts(object, string_types)
  def remove(self, axis_name):
    """Creates a new Axes object without the given axis.

    Args:
      axis_name: Name of the axis to drop.

    Returns:
      A new Axes holding the remaining axes in their original order. This
      object is not modified.

    Raises:
      KeyError: If `axis_name` is not one of these axes.
    """
    if axis_name not in self:
      raise KeyError(axis_name)
    remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
    return Axes(remaining_axes)
axes_0 = labeled_tensors[0].axes for t in labeled_tensors: if t.axes != axes_0: raise ValueError('Non-identical axes. Expected %s but got %s' % (axes_0, t.axes)) pack_op = array_ops.stack([t.tensor for t in labeled_tensors], axis=axis_position, name=scope) axes = list(axes_0.values()) axes.insert(axis_position, new_axis) return core.LabeledTensor(pack_op, axes) @tc.returns(tc.List(core.LabeledTensor)) @tc.accepts(core.LabeledTensorLike, tc.Optional(string_types), tc.Optional(string_types)) def unpack(labeled_tensor, axis_name=None, name=None): """Unpack the tensor. See tf.unpack. Args: labeled_tensor: The input tensor. axis_name: Optional name of axis to unpack. By default, the first axis is used. name: Optional op name. Returns: The list of unpacked LabeledTensors.