def test_dimshuffle_bprop(self, x, A, B):
    """
    Reorder the axes of the 2d tensor ``x`` to [B, A] and verify its
    derivative with ``check_derivative``.
    """
    target_order = [B, A]
    # random starting point drawn uniformly from [-1, 1) over x's axes
    init_value = rng.uniform(-1, 1, x.axes)
    shuffled = ng.axes_with_order(x, target_order)
    check_derivative(shuffled, x, 0.001, init_value, atol=1e-3, rtol=1e-3)
def test_dimshuffle_fprop(self, x, A, B):
    """
    dimshuffle a 2d array and make sure fprop works
    """
    # build the dimshuffle (axis reorder) in the graph
    output = ng.axes_with_order(x, [B, A])
    assert output.axes == ng.make_axes([B, A])

    # randomly initialize
    x_value = rng.uniform(-1, 1, x.axes)

    with executor(output, x) as ex:
        result = ex(x_value)
        # Swapping the two axes of a 2d tensor is a transpose.
        # NOTE(review): assumes the fixture gives x axes (A, B) — confirm.
        ng.testing.assert_allclose(result, x_value.T)
def test_idempotent_axes_c():
    """
    Test axes transformations with autodiff, case c, with broadcast, slice,
    cast and dim-shuffle.

    ``w`` is a (3, 1) variable of ones. Two copies of it pass through
    broadcast / slice / cast and are added, so ``cost = sum(2 * w) = 6`` and
    ``d(cost)/dw`` is 2 everywhere.
    """
    with ExecutorFactory() as ex:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        result_axes = [ng.make_axis(length=axis.length) for axis in axes]

        # variable
        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # broadcast lhs / rhs
        lhs = ng.broadcast(w, axes)
        rhs = ng.broadcast(w, axes)

        # slice (full slices on both axes)
        axes_slice = [slice(None, None, None), slice(None, None, None)]
        lhs_sliced = ng.tensor_slice(lhs, axes_slice)
        rhs_sliced = ng.tensor_slice(rhs, axes_slice)

        # cast rhs
        rhs_sliced_casted = ng.cast_axes(rhs_sliced, axes)

        # perform add
        result = ng.add(lhs_sliced, rhs_sliced_casted)

        # cast / dimshuffle onto fresh axes of the same lengths
        result = ng.cast_axes(result, result_axes)
        result = ng.axes_with_order(result, result_axes)

        # cost and grad
        cost = ng.sum(result, reduction_axes=result.axes)
        grad = ng.deriv(cost, w)

        grad_comp = ex.executor(grad)
        cost_comp = ex.executor(cost)

        cost_comp_ng = cost_comp()
        grad_comp_ng = grad_comp()
        grad_comp_np = np.ones((3, 1)) * 2.
        assert cost_comp_ng == 6.0
        assert np.array_equal(grad_comp_ng, grad_comp_np)
def test_dimshuffle_op():
    """
    Sandwich an axis reorder between two ``ng.negative`` ops and check the
    executed output against hand-written permuted values, for two sets of
    axis lengths.
    """
    A = ng.make_axis().named('A')
    B = ng.make_axis().named('B')
    C = ng.make_axis().named('C')
    D = ng.make_axis().named('D')

    tests = [
        {
            'input_tensor': [
                [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                 [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],
            ],
            'input_tensor_axes': (A, B, C, D),
            'output_tensor_axes': (B, D, A, C),
            'axes_lengths': {A: 1, B: 2, C: 3, D: 4},
            'expected_result': [
                [[[1, 5, 9]], [[2, 6, 10]], [[3, 7, 11]], [[4, 8, 12]]],
                [[[13, 17, 21]], [[14, 18, 22]], [[15, 19, 23]],
                 [[16, 20, 24]]],
            ],
        },
        {
            'input_tensor': [
                [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                 [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]],
                [[[25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36]],
                 [[37, 38, 39, 40], [41, 42, 43, 44], [45, 46, 47, 48]]],
            ],
            'input_tensor_axes': (A, B, C, D),
            'output_tensor_axes': (B, D, A, C),
            'axes_lengths': {A: 2, B: 2, C: 3, D: 4},
            'expected_result': [
                [[[1, 5, 9], [25, 29, 33]],
                 [[2, 6, 10], [26, 30, 34]],
                 [[3, 7, 11], [27, 31, 35]],
                 [[4, 8, 12], [28, 32, 36]]],
                [[[13, 17, 21], [37, 41, 45]],
                 [[14, 18, 22], [38, 42, 46]],
                 [[15, 19, 23], [39, 43, 47]],
                 [[16, 20, 24], [40, 44, 48]]],
            ],
        },
    ]

    for case in tests:
        # The same Axis objects are reused across cases; resize them first.
        sizes = case['axes_lengths']
        for ax in sizes:
            ax.length = sizes[ax]

        source = ng.placeholder(case['input_tensor_axes'])
        feed = np.array(case['input_tensor'], dtype=np.float32)

        # negative -> reorder -> negative: intended to force a dimshuffle op
        # into the graph rather than a plain view.
        reordered = ng.axes_with_order(ng.negative(source),
                                       case['output_tensor_axes'])
        final = ng.negative(reordered)

        with executor(final, source) as run:
            computed = run(feed)
            ng.testing.assert_allclose(computed, case['expected_result'])
def reorder_spatial_axes(tensor, channel_axis, spatial_axes):
    """
    Reorders the axes of the input tensor in preparation for a spatial op (i.e.
    convolution, deconvolution, or pooling).

    Arguments:
        tensor (TensorOp): The input tensor whose axes must be a subset of those specified
                           in channel_axis, spatial_axes and a batch axis. Missing channel
                           and spatial axes are added with length 1; a missing batch axis
                           is an error.
        channel_axis (Axis, str): The axis or axis name to use as the "channel" axis type
        spatial_axes (tuple of Axis or str): Tuple of axis or axis names to use as the "depth",
                                             "height", and "width" axis types, in that order.

    Returns:
        tensor with axes ordered as "channel", *spatial_axes, "batch" (5 dimensions
        when three spatial axes are given).

    Raises:
        IncompatibleAxesError: The tensor has more than 5 axes, has no batch axis,
                               contains none of the spatial axes, has axes outside
                               channel/spatial/batch, or would need a missing axis of
                               length > 1 to be added.
    """
    if len(tensor.axes) > 5:
        raise IncompatibleAxesError(
            "spatial ops cannot have more than 5 axes, "
            "found {}".format(len(tensor.axes)))

    def expand_with_name(tensor, axis, index=0):
        """Return ``(tensor, axis)``, inserting ``axis`` with length 1 if absent.

        ``axis`` may be an Axis or an axis name. A name already present in
        ``tensor`` resolves to that tensor's Axis object.
        """
        if isinstance(axis, Axis):
            if axis in tensor.axes:
                return tensor, axis
            if (axis.length is not None) and (axis.length > 1):
                raise IncompatibleAxesError(
                    "Cannot expand tensor to an axis with length > 1: {}"
                    ", length={}".format(axis.name, axis.length))
            # NOTE: this sets the length on the caller's Axis object in place.
            axis.length = 1
        else:
            if axis in tensor.axes.names:
                return tensor, tensor.axes.find_by_name(axis)[0]
            axis = ng.make_axis(name=axis, length=1)
        return ng.expand_dims(tensor, axis, index), axis

    def not_in(axes, ax):
        """Return True if ``ax`` (an Axis or an axis name) is absent from ``axes``."""
        if isinstance(ax, string_types):
            # Look names up by name, as expand_with_name does. The old approach
            # built a throwaway length-less Axis and relied on it comparing
            # equal to the tensor's real axis.
            return ax not in axes.names
        return ax not in axes

    batch_axis = tensor.axes.batch_axis()
    if batch_axis is None:
        raise IncompatibleAxesError(
            'Spatial ops require a batch axis, but none were found: '
            '{}'.format(tensor.axes))

    if all(not_in(tensor.axes, ax) for ax in spatial_axes):
        raise IncompatibleAxesError(
            "spatial_axes provided were {}, but none were found in the "
            "tensor: {}. All spatial ops require at least one spatial "
            "dimension.".format(spatial_axes, tensor.axes))

    tensor, channel_axis = expand_with_name(tensor, channel_axis)

    # Resolve each spatial axis (adding any missing ones), preserving order.
    resolved_spatial = []
    for ax in spatial_axes:
        tensor, ax = expand_with_name(tensor, ax)
        resolved_spatial.append(ax)

    new_axes = channel_axis + ng.make_axes(resolved_spatial) + batch_axis
    if tensor.axes.is_not_equal_set(new_axes):
        raise IncompatibleAxesError(
            "Found extra axes: "
            "{}".format(set(tensor.axes).difference(set(new_axes))))

    return ng.axes_with_order(tensor, new_axes)
def test_fail_on_axis_reuse(self, x, A, B):
    """Requesting an order that lists B twice must raise ValueError."""
    repeated_order = [A, B, B]
    with pytest.raises(ValueError):
        ng.axes_with_order(x, repeated_order)
def test_fail_on_missing_and_extra_axis(self, x, A, C):
    """Requesting the order [A, C] must raise ValueError.

    Presumably x has axes (A, B) from the fixture, so this order both drops
    one of x's axes and introduces one x does not have.
    """
    bad_order = [A, C]
    with pytest.raises(ValueError):
        ng.axes_with_order(x, bad_order)
def test_fail_on_missing(self, x, B):
    """Requesting the order [B, B] must raise ValueError.

    NOTE(review): presumably x has axes (A, B), so A is missing here. The
    order also repeats B, so this does not isolate the missing-axis check.
    """
    incomplete_order = [B, B]
    with pytest.raises(ValueError):
        ng.axes_with_order(x, incomplete_order)