def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
    """Gathers entries of `labeled_tensor` along a single named axis.

    The target axis is first moved to the front so that `array_ops.gather`
    (called without an explicit axis, so it indexes the leading dimension)
    can select along it. The result is then transposed back to the original
    axis order of `labeled_tensor`.

    Args:
        labeled_tensor: The LabeledTensor to index.
        indexer: Indices forwarded unchanged to `array_ops.gather`.
        axis: Axis object used to describe the gathered dimension of the
            result. Only `axis.name` is used to locate that dimension in
            `labeled_tensor`.
            NOTE(review): presumably its size/labels correspond to
            `indexer`; confirm against callers.
        name: Optional name for the op; falls back to an 'lt_take' scope.

    Returns:
        A LabeledTensor whose axis order matches `labeled_tensor`.
    """
    with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
        # The input's axes with `axis.name` dropped.
        other_axes = labeled_tensor.axes.remove(axis.name)
        # Put the gathered axis in the leading position for gather().
        front_axes = core.Axes([axis] + list(other_axes.values()))
        front_lt = core.transpose(labeled_tensor, front_axes.keys())
        gathered = array_ops.gather(front_lt.tensor, indexer)
        gathered_lt = core.LabeledTensor(gathered, front_axes)
        # Restore the caller's axis order; the final op takes the scope name.
        original_order = labeled_tensor.axes.keys()
        return core.transpose(gathered_lt, original_order, name=scope)
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
    """Selects `indexer` entries along one named axis of a LabeledTensor.

    Implementation: transpose so the named axis leads, apply
    `array_ops.gather` (which, with no axis argument, indexes dimension 0),
    then transpose back to the input's axis order.

    Args:
        labeled_tensor: Source LabeledTensor.
        indexer: Indices passed directly to `array_ops.gather`.
        axis: Axis object attached to the gathered dimension of the result;
            `axis.name` identifies which input dimension is indexed.
            NOTE(review): presumably describes the selected entries; confirm
            with callers.
        name: Optional op name; defaults to the 'lt_take' scope.

    Returns:
        A LabeledTensor laid out in the same axis order as `labeled_tensor`.
    """
    with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
        remaining = list(labeled_tensor.axes.remove(axis.name).values())
        leading_axes = core.Axes([axis] + remaining)
        leading_lt = core.transpose(labeled_tensor, leading_axes.keys())
        picked_lt = core.LabeledTensor(
            array_ops.gather(leading_lt.tensor, indexer), leading_axes)
        return core.transpose(
            picked_lt, labeled_tensor.axes.keys(), name=scope)
def test_reverse(self):
    """impose_axis_order with a reversed order matches an explicit transpose."""
    axis_order = ['w', 'x', 'y', 'z']
    reversed_order = axis_order[::-1]

    # Case 1: the tensor carries every axis named in the imposed order.
    four_axis_lt = core.LabeledTensor(
        tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
    actual = core.impose_axis_order(four_axis_lt, reversed_order)
    expected = core.transpose(four_axis_lt, reversed_order)
    self.assertLabeledTensorsEqual(expected, actual)

    # Case 2: the tensor lacks 'z'; the expected result holds only the
    # present axes, in reversed order.
    three_axis_lt = core.LabeledTensor(
        tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
    actual = core.impose_axis_order(three_axis_lt, reversed_order)
    expected = core.transpose(three_axis_lt, ['y', 'x', 'w'])
    self.assertLabeledTensorsEqual(expected, actual)
def test_reverse(self):
    """impose_axis_order with a reversed order matches an explicit transpose."""
    axis_order = ['w', 'x', 'y', 'z']
    reversed_order = axis_order[::-1]

    # Case 1: all four axes of the imposed order are present.
    full_tensor = array_ops.reshape(math_ops.range(24), (1, 2, 3, 4))
    full_lt = core.LabeledTensor(full_tensor, axis_order)
    actual = core.impose_axis_order(full_lt, reversed_order)
    expected = core.transpose(full_lt, reversed_order)
    self.assertLabeledTensorsEqual(expected, actual)

    # Case 2: 'z' is absent; only the present axes appear, reversed.
    partial_tensor = array_ops.reshape(math_ops.range(6), (1, 2, 3))
    partial_lt = core.LabeledTensor(partial_tensor, axis_order[:3])
    actual = core.impose_axis_order(partial_lt, reversed_order)
    expected = core.transpose(partial_lt, ['y', 'x', 'w'])
    self.assertLabeledTensorsEqual(expected, actual)
def test_default_axis_order(self):
    """transpose with no explicit order reverses every axis."""
    transpose_lt = core.transpose(self.original_lt)
    # The permutation [3, 2, 1, 0] implies the fixture tensor is rank 4.
    golden_tensor = array_ops.transpose(self.tensor, [3, 2, 1, 0])
    reversed_axes = list(self.original_lt.axes.values())[::-1]
    golden_lt = core.LabeledTensor(golden_tensor, reversed_axes)
    self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test(self):
    """transpose permutes both the tensor data and its axis metadata."""
    target_order = ['z', 'channel', 'x', 'probs']
    transpose_lt = core.transpose(self.original_lt, target_order)
    # NOTE(review): presumably self.a2/a1/a0/a3 are the 'z'/'channel'/'x'/
    # 'probs' axes of the fixture, matching the permutation [2, 1, 0, 3].
    expected_tensor = tf.transpose(self.tensor, [2, 1, 0, 3])
    expected_axes = [self.a2, self.a1, self.a0, self.a3]
    golden_lt = core.LabeledTensor(expected_tensor, expected_axes)
    self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test_default_axis_order(self):
    """Calling transpose without an axis order reverses all axes."""
    transpose_lt = core.transpose(self.original_lt)
    # Full reversal of a rank-4 tensor, paired with the reversed axis list.
    golden_tensor = tf.transpose(self.tensor, [3, 2, 1, 0])
    axes_in_reverse = list(self.original_lt.axes.values())[::-1]
    golden_lt = core.LabeledTensor(golden_tensor, axes_in_reverse)
    self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test_scope(self):
    """impose_axis_order with no explicit order uses the active axis_order_scope."""
    axis_order = ['w', 'x', 'y', 'z']
    reversed_order = axis_order[::-1]
    lt = core.LabeledTensor(
        tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)

    expected = core.transpose(lt, reversed_order)
    # The order comes only from the enclosing scope here.
    with core.axis_order_scope(reversed_order):
        actual = core.impose_axis_order(lt)
    self.assertLabeledTensorsEqual(expected, actual)
def test(self):
    """An explicit axis order permutes data and axes consistently."""
    transpose_lt = core.transpose(
        self.original_lt, ['z', 'channel', 'x', 'probs'])
    # NOTE(review): assumes the fixture axes a0..a3 line up with the
    # permutation [2, 1, 0, 3]; confirm against setUp.
    permuted = tf.transpose(self.tensor, [2, 1, 0, 3])
    golden_lt = core.LabeledTensor(
        permuted, [self.a2, self.a1, self.a0, self.a3])
    self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test_scope(self):
    """impose_axis_order falls back to the order set by axis_order_scope."""
    axis_order = ['w', 'x', 'y', 'z']
    reversed_order = axis_order[::-1]
    source = array_ops.reshape(math_ops.range(24), (1, 2, 3, 4))
    lt = core.LabeledTensor(source, axis_order)

    expected = core.transpose(lt, reversed_order)
    # No order argument: the reversed order is supplied by the scope.
    with core.axis_order_scope(reversed_order):
        actual = core.impose_axis_order(lt)
    self.assertLabeledTensorsEqual(expected, actual)
def test_matrix_matrix(self):
    """ops.matmul gives the same product regardless of operand axis order."""

    def transpose(x):
        # Reverse the axis order of a LabeledTensor (a named def rather than
        # a lambda bound to a name, per PEP 8).
        return core.transpose(x, list(x.axes.keys())[::-1])

    xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
    yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])

    matmul_lt = ops.matmul(xy_lt, yz_lt)
    golden_lt = core.LabeledTensor(
        tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
    self.assertLabeledTensorsEqual(matmul_lt, golden_lt)

    # Transposing either or both operands must not change the result; the
    # test expects contraction over the shared 'y' axis, not by position.
    transposed_cases = [
        (xy_lt, transpose(yz_lt)),
        (transpose(xy_lt), yz_lt),
        (transpose(xy_lt), transpose(yz_lt)),
    ]
    for lhs, rhs in transposed_cases:
        self.assertLabeledTensorsEqual(ops.matmul(lhs, rhs), golden_lt)

    # Swapping operands yields the product with axes ('z', 'x').
    matmul_lt = ops.matmul(yz_lt, xy_lt)
    self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))
def test_invalid_input(self):
    """transpose raises ValueError for orders that are not valid permutations."""
    bad_orders = [
        ['channel', 'x', 'probs'],     # omits 'z'
        ['z', 'foo', 'x', 'probs'],    # 'foo' is not an axis of the fixture
    ]
    for bad_order in bad_orders:
        with self.assertRaises(ValueError):
            core.transpose(self.original_lt, bad_order)
def test_identity(self):
    """Transposing into the tensor's current axis order leaves it unchanged."""
    current_order = self.original_lt.axes.keys()
    transpose_lt = core.transpose(self.original_lt, current_order)
    self.assertLabeledTensorsEqual(transpose_lt, self.original_lt)
def test_name(self):
    """The transposed tensor's name contains the 'lt_transpose' default."""
    same_order = self.original_lt.axes.keys()
    result_name = core.transpose(self.original_lt, same_order).name
    self.assertIn('lt_transpose', result_name)
def test_transposed(self):
    """concat along 'channel' raises ValueError for a transposed input."""
    # NOTE(review): presumably fails because green's axis order no longer
    # matches red_lt's; confirm against ops.concat.
    new_order = ['probs', 'channel', 'z', 'x']
    green_transposed = core.transpose(self.green_lt, new_order)
    to_concat = [self.red_lt, green_transposed]
    with self.assertRaises(ValueError):
        ops.concat(to_concat, 'channel')