示例#1
0
    def test_dimshuffle_bprop(self, x, A, B):
        """
        Verify backpropagation through a 2-D axis reorder (dimshuffle).

        The tensor ``x`` is reordered to ``[B, A]`` and its derivative with
        respect to ``x`` is checked by ``check_derivative`` using a step of
        0.001 and absolute/relative tolerances of 1e-3.
        """
        # random input values in the range (-1, 1), shaped like x
        x_value = rng.uniform(-1, 1, x.axes)

        reordered = ng.axes_with_order(x, [B, A])
        step = 0.001
        check_derivative(reordered, x, step, x_value, atol=1e-3, rtol=1e-3)
示例#2
0
    def test_dimshuffle_fprop(self, x, A, B):
        """
        Dimshuffle a 2-D tensor and make sure the forward pass works.

        Reorders ``x`` to ``[B, A]``, checks that the resulting op carries
        exactly those axes, and compares the executed values with the NumPy
        transpose of the random input.
        """
        # build the axis-reordering (dimshuffle) op in the graph
        output = ng.axes_with_order(x, [B, A])

        assert output.axes == ng.make_axes([B, A])

        # randomly initialize the input
        x_value = rng.uniform(-1, 1, x.axes)

        with executor(output, x) as ex:
            result = ex(x_value)

        # Reordering a 2-D tensor to [B, A] is a plain transpose.
        # NOTE(review): assumes the fixture gives x the axes (A, B) - confirm.
        ng.testing.assert_allclose(result, x_value.T)
def test_idempotent_axes_c():
    """
    Test axes transformations with autodiff, case c, with broadcast,
    slice, cast and dim-shuffle.

    A (3, 1) variable of ones is broadcast, sliced, cast, added to itself,
    cast/reordered again and summed. The sum must be 6 and the gradient
    with respect to the variable must be 2 everywhere.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        result_axes = [ng.make_axis(length=axis.length) for axis in axes]

        # variable
        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # broadcast both operands to w's own axes (which already include a
        # length-1 axis)
        lhs = ng.broadcast(w, axes)
        rhs = ng.broadcast(w, axes)

        # full slice over both axes
        axes_slice = [slice(None, None, None), slice(None, None, None)]
        lhs_sliced = ng.tensor_slice(lhs, axes_slice)
        rhs_sliced = ng.tensor_slice(rhs, axes_slice)

        # cast rhs
        rhs_sliced_casted = ng.cast_axes(rhs_sliced, axes)

        # perform add
        result = ng.add(lhs_sliced, rhs_sliced_casted)

        # cast to fresh axes of the same lengths, then dimshuffle into that
        # same order
        result = ng.cast_axes(result, result_axes)
        result = ng.axes_with_order(result, result_axes)

        # cost and grad
        cost = ng.sum(result, reduction_axes=result.axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)

        cost_comp_ng = cost_comp()
        grad_comp_ng = grad_comp()
        # w contributes twice to every element (lhs + rhs), so d(cost)/dw == 2
        grad_comp_np = np.ones((3, 1)) * 2.
        assert cost_comp_ng == 6.0
        assert np.array_equal(grad_comp_ng, grad_comp_np)
def test_dimshuffle_op():
    """
    Reorder a 4-D tensor from (A, B, C, D) to (B, D, A, C) between two
    negations and compare the executed result with hand-written values.
    """
    A = ng.make_axis().named('A')
    B = ng.make_axis().named('B')
    C = ng.make_axis().named('C')
    D = ng.make_axis().named('D')

    input_order = (A, B, C, D)
    output_order = (B, D, A, C)

    # Each case: (lengths of A, B, C, D), input values, expected output.
    # Inputs hold 1..N in row-major order over (A, B, C, D); the expected
    # values are equal to np.transpose(input, (1, 3, 0, 2)).
    cases = [
        (
            (1, 2, 3, 4),
            np.arange(1, 25, dtype=np.float32).reshape(1, 2, 3, 4),
            [
                [[[1, 5, 9]], [[2, 6, 10]], [[3, 7, 11]], [[4, 8, 12]]],
                [[[13, 17, 21]], [[14, 18, 22]], [[15, 19, 23]],
                 [[16, 20, 24]]],
            ],
        ),
        (
            (2, 2, 3, 4),
            np.arange(1, 49, dtype=np.float32).reshape(2, 2, 3, 4),
            [
                [
                    [[1, 5, 9], [25, 29, 33]],
                    [[2, 6, 10], [26, 30, 34]],
                    [[3, 7, 11], [27, 31, 35]],
                    [[4, 8, 12], [28, 32, 36]],
                ],
                [
                    [[13, 17, 21], [37, 41, 45]],
                    [[14, 18, 22], [38, 42, 46]],
                    [[15, 19, 23], [39, 43, 47]],
                    [[16, 20, 24], [40, 44, 48]],
                ],
            ],
        ),
    ]

    for lengths, input_value, expected in cases:
        for axis, length in zip(input_order, lengths):
            axis.length = length

        placeholder = ng.placeholder(input_order)

        # This list of operations should add a dimshuffle operation to the graph.
        negated = ng.negative(placeholder)
        reordered = ng.axes_with_order(negated, output_order)
        restored = ng.negative(reordered)

        with executor(restored, placeholder) as ex:
            out = ex(input_value)
            ng.testing.assert_allclose(out, expected)
示例#5
0
def reorder_spatial_axes(tensor, channel_axis, spatial_axes):
    """
    Reorders the axes of the input tensor in preparation for a spatial op (i.e. convolution,
    deconvolution, or pooling).

    Arguments:
        tensor (TensorOp): The input tensor whose axes must be a subset of those specified in
            channel_axis, spatial_axes and a batch axis. Missing channel/spatial axes are
            added as length-1 axes.
        channel_axis (Axis, str): The axis or axis name to use as the "channel" axis type.
            If an Axis object not already in the tensor is given, its ``length`` is set to 1
            in place.
        spatial_axes (tuple of Axis or str): Tuple of axis or axis names to use as the "depth",
            "height", and "width" axis types, in that order. Axis objects not already in the
            tensor have their ``length`` set to 1 in place.

    Returns:
        tensor with 5 dimensions, ordered as "channel", "depth", "height", "width", "batch"

    Raises:
        IncompatibleAxesError: The tensor has more than 5 axes, has no batch axis, contains
            none of the spatial axes, would need expanding along an axis of length > 1, or
            has axes other than the channel, spatial and batch axes.
    """

    if len(tensor.axes) > 5:
        raise IncompatibleAxesError(
            "spatial ops cannot have more than 5 axes, "
            "found {}".format(len(tensor.axes)))

    def expand_with_name(tensor, axis, index=0):
        """Return ``(tensor, axis)``, adding ``axis`` as a length-1 axis if it is absent."""
        if isinstance(axis, Axis):
            if axis in tensor.axes:
                return tensor, axis
            if (axis.length is not None) and (axis.length > 1):
                raise IncompatibleAxesError(
                    "Cannot expand tensor to an axis with length > 1: {}"
                    ", length={}".format(axis.name, axis.length))
            # NOTE: mutates the caller-supplied Axis object
            axis.length = 1
        else:
            if axis in tensor.axes.names:
                return tensor, tensor.axes.find_by_name(axis)[0]
            axis = ng.make_axis(name=axis, length=1)
        # insertion position does not matter: axes are reordered at the end
        return ng.expand_dims(tensor, axis, index), axis

    def not_in(axes, ax):
        """Return True if ``ax`` (an Axis or an axis name) is not among ``axes``."""
        if isinstance(ax, string_types):
            # Match names by name, consistent with expand_with_name. Building a
            # fresh length-less Axis and testing membership would not match an
            # existing axis of the same name if Axis equality also compares length.
            return ax not in axes.names

        return ax not in axes

    batch_axis = tensor.axes.batch_axis()
    if batch_axis is None:
        raise IncompatibleAxesError(
            'Spatial ops require a batch axis, but none were found: '
            '{}'.format(tensor.axes))

    if all(not_in(tensor.axes, ax) for ax in spatial_axes):
        raise IncompatibleAxesError(
            "spatial_axes provided were {}, but none were found in the "
            "tensor: {}. All spatial ops require at least one spatial "
            "dimension.".format(spatial_axes, tensor.axes))

    tensor, channel_axis = expand_with_name(tensor, channel_axis)

    spatial_axes = list(spatial_axes)
    for ii, ax in enumerate(spatial_axes):
        tensor, ax = expand_with_name(tensor, ax)
        spatial_axes[ii] = ax

    new_axes = channel_axis + ng.make_axes(spatial_axes) + batch_axis
    # any axis left over beyond channel/spatial/batch cannot be placed
    if tensor.axes.is_not_equal_set(new_axes):
        raise IncompatibleAxesError(
            "Found extra axes: "
            "{}".format(set(tensor.axes).difference(set(new_axes))))

    return ng.axes_with_order(tensor, new_axes)
示例#6
0
 def test_fail_on_axis_reuse(self, x, A, B):
     """Requesting an order that lists axis B twice must raise ValueError."""
     order_with_duplicate = [A, B, B]
     with pytest.raises(ValueError):
         ng.axes_with_order(x, order_with_duplicate)
示例#7
0
 def test_fail_on_missing_and_extra_axis(self, x, A, C):
     """
     Requesting the order [A, C] must raise ValueError.

     NOTE(review): assumes the fixture gives x the axes (A, B), so this order
     omits B and adds the foreign axis C - confirm.
     """
     mismatched_order = [A, C]
     with pytest.raises(ValueError):
         ng.axes_with_order(x, mismatched_order)
示例#8
0
 def test_fail_on_missing(self, x, B):
     """
     Requesting an order that omits one of the tensor's axes must raise ValueError.

     Only ``B`` is requested, so no axis is repeated. Using ``[B, B]`` would also
     repeat B, and that could fail on the duplicate-axis path already exercised
     by ``test_fail_on_axis_reuse``, hiding a broken missing-axis check.
     NOTE(review): assumes the fixture gives x two axes, (A, B) - confirm.
     """
     with pytest.raises(ValueError):
         ng.axes_with_order(x, [B])