def setUp(self):
    """Create the Axis and Axes fixtures stored on `self` for the tests."""
    dim_seven = tensor_shape.Dimension(7)
    dim_eight = tensor_shape.Dimension(8)

    # A single Axis named '8' backed by the size-8 dimension.
    self.i_8 = core.Axis('8', dim_eight)

    # a0 and a1 are built from identical inputs (one 'd7' axis each).
    self.a0 = core.Axes([('d7', dim_seven)])
    self.a1 = core.Axes([('d7', dim_seven)])

    # a2 adds a labeled 'rgb' axis after 'd7'.
    self.a2 = core.Axes([
        ('d7', dim_seven),
        ('rgb', ['red', 'green', 'blue']),
    ])

    # a3 pairs the '8' dimension with a 'range' axis labeled 0..6.
    self.a3 = core.Axes([
        ('8', dim_eight),
        ('range', range(7)),
    ])
def test_bijection_with_labels(self):
    """Encoding then decoding with a ReshapeCoder restores the input.

    Also checks that the encoded tensor carries the expected axes: the
    batch/row/column axes followed by the new 'depth' and 'other' axes.
    """
    depth_size = len(self.channels) * len(self.masks)
    depth_axis = core.Axis('depth', range(depth_size))

    coder = sugar.ReshapeCoder(
        ['channel', 'mask'],
        [depth_axis, ('other', ['label'])],
    )

    encoded = coder.encode(self.masked_image_lt)
    expected_axes = core.Axes([
        self.batch_axis,
        self.row_axis,
        self.column_axis,
        depth_axis,
        ('other', ['label']),
    ])
    self.assertEqual(encoded.axes, expected_axes)

    # Decoding the encoded tensor must give back the original.
    decoded = coder.decode(encoded)
    self.assertLabeledTensorsEqual(decoded, self.masked_image_lt)
def rename_axis(labeled_tensor, existing_name, new_name, name=None):
    """Rename an axis of LabeledTensor.

    The rename is implemented as a reshape that replaces the single axis
    `existing_name` with a new axis called `new_name` carrying the same value.

    Args:
      labeled_tensor: The input tensor.
      existing_name: Name for an existing axis on the input.
      new_name: Desired replacement name.
      name: Optional op name.

    Returns:
      LabeledTensor with renamed axis.

    Raises:
      ValueError: If `existing_name` is not an axis on the input.
    """
    with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
        input_axes = labeled_tensor.axes
        if existing_name in input_axes:
            # Same value as the old axis, only the name differs.
            replacement = core.Axis(new_name, input_axes[existing_name].value)
            return reshape(
                labeled_tensor, [existing_name], [replacement], name=scope)

        raise ValueError(
            'existing_name %r are not contained in the set of axis '
            'names %r on the input labeled tensor' %
            (existing_name, input_axes.keys()))
def select(labeled_tensor, selection, name=None):
    """Slice out a subset of the tensor.

    Args:
      labeled_tensor: The input tensor.
      selection: A dictionary mapping an axis name to a scalar, slice or list
        of values to select. Currently supports two types of selections:
          (a) Any number of scalar and/or slice selections.
          (b) Exactly one list selection, without any scalars or slices.
      name: Optional op name.

    Returns:
      The selection as a `LabeledTensor`.

    Raises:
      ValueError: If the tensor doesn't have an axis in the selection or if
        that axis lacks labels.
      KeyError: If any labels in a selection are not found in the original
        axis.
      NotImplementedError: If you attempt to combine a list selection with
        scalar selection or another list selection, or if a slice selection
        specifies a step.
      TypeError: If a selection value is not a slice, a hashable scalar or a
        list.
    """
    # `collections.Hashable` was a deprecated alias that Python 3.10 removed;
    # the ABC lives in `collections.abc`. Fall back to `collections` itself on
    # interpreters that do not provide the `collections.abc` submodule.
    try:
        from collections import abc as collections_abc
    except ImportError:
        collections_abc = collections

    with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
        labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

        # Scalar and slice selections, keyed by axis name.
        slices = {}
        # List selections (at most one is supported), keyed by axis name.
        indexers = {}
        for axis_name, value in selection.items():
            if axis_name not in labeled_tensor.axes:
                raise ValueError(
                    'The tensor does not have an axis named %s. Its axes are: %r'
                    % (axis_name, labeled_tensor.axes.keys()))
            axis = labeled_tensor.axes[axis_name]
            if axis.labels is None:
                raise ValueError(
                    'The axis named %s does not have labels. The axis is: %r' %
                    (axis_name, axis))

            if isinstance(value, slice):
                # TODO(shoyer): consider deprecating using slices in favor of
                # lists
                if value.start is None:
                    start = None
                else:
                    start = axis.index(value.start)

                if value.stop is None:
                    stop = None
                else:
                    # For now, follow the pandas convention of making labeled
                    # slices inclusive of both bounds.
                    stop = axis.index(value.stop) + 1

                if value.step is not None:
                    raise NotImplementedError(
                        'slicing with a step is not yet supported')

                slices[axis_name] = slice(start, stop)

            # Needs to be after checking for slices, since slice objects claim
            # to be instances of Hashable but hash() on them may fail.
            elif isinstance(value, collections_abc.Hashable):
                slices[axis_name] = axis.index(value)

            elif isinstance(value, list):
                if indexers:
                    raise NotImplementedError(
                        'select does not yet support more than one list '
                        'selection at the same time')
                # Translate each requested label to its position on the axis.
                indexer = [axis.index(v) for v in value]
                indexers[axis_name] = ops.convert_to_tensor(
                    indexer, dtype=dtypes.int64)

            else:
                # If type checking is working properly, this shouldn't be
                # possible.
                raise TypeError('cannot handle arbitrary types')

        if indexers and slices:
            raise NotImplementedError(
                'select does not yet support combined scalar and list selection'
            )

        # For now, handle array selection separately, because tf.gather_nd
        # does not support gradients yet. Later, using gather_nd will let us
        # combine these paths.
        if indexers:
            # Exactly one entry is present here; unpack it.
            (axis_name, indexer), = indexers.items()
            # The resulting axis is labeled with the requested values, in the
            # order they were requested.
            axis = core.Axis(axis_name, selection[axis_name])
            return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope)
        else:
            return core.slice_function(labeled_tensor, slices, name=scope)
def test_concat_unknown(self):
    """Concatenating two 'rgb' axes whose value is None equals the first."""
    first = core.Axis('rgb', None)
    second = core.Axis('rgb', None)

    combined = core.concat_axes([first, second])

    self.assertEqual(combined, first)
def test_concat_different_names(self):
    """concat_axes raises ValueError when the axis names differ."""
    # Same labels, different axis names.
    mismatched = [
        core.Axis('red', ['red']),
        core.Axis('green', ['red']),
    ]
    with self.assertRaises(ValueError):
        core.concat_axes(mismatched)
def test_concat_single(self):
    """Concatenating a one-element list returns an axis equal to it."""
    only_axis = core.Axis('rgb', ['red'])

    combined = core.concat_axes([only_axis])

    self.assertEqual(combined, only_axis)
def test_axis_value_input(self):
    """An Axis built from a range, a list or a NumPy array equals i_range."""
    expected = self.i_range
    equivalent_values = (range(7), list(range(7)), np.arange(7))
    for candidate in equivalent_values:
        rebuilt = core.Axis(expected.name, candidate)
        self.assertEqual(expected, rebuilt)
def test_axis_input(self):
    """Rebuilding an Axis from its own name and value gives an equal Axis."""
    fixtures = (
        self.i_7,
        self.i_7p,
        self.i_rgb,
        self.i_range,
        self.i_unknown,
    )
    for original in fixtures:
        rebuilt = core.Axis(original.name, original.value)
        self.assertEqual(original, rebuilt)