def test_flat_tensor_dot_tensor():
    """
    Ensure that a flattened argument axis is not unflattened in the result.

    An all-ones tensor over (H, W, C) is flattened at index 2 and dotted
    with an all-ones tensor over (C, K); every element of the evaluated
    result is expected to equal ``C.length``.
    """
    height = ng.make_axis(2)
    width = ng.make_axis(7)
    channel = ng.make_axis(3)
    feature = ng.make_axis(11)

    # Right-hand operand of the dot: ones over (H, W, C), flattened at 2.
    image_axes = ng.make_axes([height, width, channel])
    image = ng.constant(np.ones(image_axes.lengths), axes=image_axes)
    flat_image = ng.flatten_at(image, 2)

    # Left-hand operand of the dot: ones over (C, K).
    weight_axes = ng.make_axes([channel, feature])
    weights = ng.constant(np.ones(weight_axes.lengths), axes=weight_axes)

    dot_op = ng.dot(weights, flat_image)

    with ExecutorFactory() as factory:
        computation = factory.executor(dot_op)
        actual = computation()

    expected = np.ones_like(actual) * channel.length
    ng.testing.assert_allclose(actual, expected)
Exemple #2
0
def test_conv_flatten_deriv(n4_hw12_c3_5x5):
    """
    Test deriv of conv followed by flatten.

    Builds a convolution from the fixture's parameters, reorders and slices
    its output, flattens it at index 1 and sums it into a scalar cost. The
    forward value and the symbolic derivatives w.r.t. both the filter and
    the input are compared against reference values.
    """
    conv_cfg = ConvParams(**n4_hw12_c3_5x5)

    rsck_axes = ng.make_axes([conv_cfg.ax_f[2], conv_cfg.ax_f[3],
                              conv_cfg.ax_f[0], conv_cfg.ax_f[-1]])
    # Fresh axes with the same lengths and a 'p' suffix on each name.
    rsck_prime_axes = ng.make_axes([
        ng.make_axis(name=axis.name + 'p', length=axis.length)
        for axis in rsck_axes
    ])
    nmpqk_axes = ng.make_axes([conv_cfg.ax_o[-1], conv_cfg.ax_o[1], conv_cfg.ax_o[2],
                               conv_cfg.ax_o[3], conv_cfg.ax_o[0]])

    # broadcast input / filter axes
    input_var = ng.variable(conv_cfg.ax_i).named('input')
    input_ones = np.ones(input_var.axes.lengths)

    filter_var = ng.variable(rsck_prime_axes).named('filter')
    filt_rsck = ng.cast_axes(filter_var, rsck_axes).named('frsck')
    filt_trsck = ng.expand_dims(filt_rsck, conv_cfg.ax_f[1], 0).named('ftrsck')
    filt_ctrsk = ng.axes_with_order(filt_trsck, axes=conv_cfg.ax_f).named('ctrsk')

    # convolution
    conv_kmpqn = ng.convolution(conv_cfg.conv_params, input_var, filt_ctrsk,
                                axes=conv_cfg.ax_o)
    conv_nmpqk = ng.axes_with_order(conv_kmpqn, axes=nmpqk_axes)

    # slice away the oD
    keep_all = slice(None)
    conv_npqk = ng.tensor_slice(conv_nmpqk, [keep_all, 0, keep_all, keep_all, keep_all])

    flat_out = ng.flatten_at(conv_npqk, idx=1)

    # cost and grad
    cost = ng.sum(flat_out, out_axes=())

    filter_ones = np.ones(filter_var.axes.lengths)

    with ExecutorFactory() as factory:

        forward = factory.executor(flat_out, filter_var, input_var)
        d_filter_numeric = factory.numeric_derivative(cost, filter_var, 1.0, input_var)
        d_filter_symbolic = factory.derivative(cost, filter_var, input_var)

        d_input_numeric = factory.numeric_derivative(cost, input_var, 1.0, filter_var)
        d_input_symbolic = factory.derivative(cost, input_var, filter_var)

        # With all-ones operands the expected output value is the product of
        # every filter axis length except the last.
        actual = forward(filter_ones, input_ones)
        expected = np.full_like(actual, np.prod(conv_cfg.ax_f.lengths[:-1]))
        ng.testing.assert_allclose(actual, expected)

        numeric = d_filter_numeric(filter_ones, input_ones)
        symbolic = d_filter_symbolic(filter_ones, input_ones)
        ng.testing.assert_allclose(numeric, symbolic)

        numeric = d_input_numeric(input_ones, filter_ones)
        symbolic = d_input_symbolic(input_ones, filter_ones)
        ng.testing.assert_allclose(numeric, symbolic)
Exemple #3
0
    def Reshape(self, tf_node, inputs):
        """
        Reshapes a tensor.

        Constant tensors are reshaped eagerly with numpy. Non-constant
        tensors are only supported when the target shape is 1d or 2d, in
        which case the reshape is expressed as a flatten followed by a cast
        to positional axes.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            ValueError: if the input and output shapes hold a different
                number of elements, or if a 2d target shape cannot be
                produced by flattening the leading input axes together.
            NotImplementedError: if the tensor is not constant and the
                target shape is neither 1d nor 2d.

        Inputs to tf_node:
            tensor, shape, name
        """
        # TODO: currently only support constants and flatten to 1d and 2d
        # get inputs
        tensor, shape = inputs

        # get input and output shape (the target shape is read from
        # ``shape.const``)
        shape_i = tensor.shape.lengths
        shape_o = tuple(shape.const.astype(int))
        if np.prod(shape_i) != np.prod(shape_o):
            raise ValueError("Total size of input and output dimension "
                             "mismatch.")

        if tensor.const is not None:
            # reshape const
            np_val = np.reshape(tensor.const, shape_o)
            return ng.constant(np_val,
                               make_pos_axes(np_val.shape)).named(tf_node.name)

        ndims_o = len(shape_o)
        if ndims_o not in (1, 2):
            raise NotImplementedError("Reshape can only support flatten "
                                      "to 1d or 2d.")
        if ndims_o == 1:
            tensor = ng.flatten(tensor)
        else:
            # The first output dimension must equal the product of the first
            # k input dimensions; the flatten point is then k.
            cumprods = list(np.cumprod(shape_i))
            if shape_o[0] not in cumprods:
                raise ValueError("Reshape from %s to %s cannot be expressed "
                                 "as a flatten." % (tuple(shape_i), shape_o))
            flatten_at_idx = cumprods.index(shape_o[0]) + 1
            tensor = ng.flatten_at(tensor, flatten_at_idx)
        res = ng.cast_axes(tensor, make_pos_axes(shape_o))
        return res.named(tf_node.name)
Exemple #4
0
def Flatten(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Flatten the input tensor into a 2D matrix.

    The split point is read from the node's `axis` attribute (default 1)
    and passed to `ng.flatten_at`; the result is cast to positional axes.

    :param onnx_node: wrapped ONNX node providing `name` and the `axis` attribute
    :param ng_inputs: list of ngraph input ops; only the first one is used
    :return: the flattened op with positional axes
    :raises ValueError: if `axis` is negative or greater than the number of
        axes of the input tensor
    """
    data = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)

    if axis < 0 or axis > len(data.axes):
        # Format the message explicitly: exceptions do not interpolate
        # extra positional arguments the way logging calls do.
        raise ValueError(
            'Flatten node (%s): %d is not a valid value for `axis`.'
            % (onnx_node.name, axis))

    return cast_to_pos_axes(ng.flatten_at(data, axis))
Exemple #5
0
def test_conv_flatten_deriv(transformer_factory):
    """
    Test deriv of conv followed by flatten.

    Convolves an all-ones image with an all-ones filter, reorders and
    slices the output, flattens it at index 1 and sums it. The forward
    value and the gradient w.r.t. the filter are compared to fixed
    reference constants.
    """
    # set shape
    C, D, H, W, N = (3, 1, 28, 28, 8)
    C, T, R, S, K = (3, 1, 5, 5, 32)

    # i, f, o axes
    image_axes = ng.make_axes([ax.C, ax.D, ax.H, ax.W, ax.N])
    filter_axes = ng.make_axes([ax.C, ax.T, ax.R, ax.S, ax.K])
    conv_axes = ng.make_axes([
        ng.make_axis(32, roles=[ar.Channel]),
        ng.make_axis(1, roles=[ar.Depth]),
        ng.make_axis(24, roles=[ar.Height]),
        ng.make_axis(24, roles=[ar.Width]),
        ax.N,
    ])
    image_axes.set_shape((C, D, H, W, N))
    filter_axes.set_shape((C, T, R, S, K))
    conv_params = dict(pad_d=0, pad_h=0, pad_w=0, str_d=1, str_h=1, str_w=1)
    rsck_axes = ng.make_axes([ax.R, ax.S, ax.C, ax.K])
    # Fresh axes with the same lengths as (R, S, C, K).
    rsck_prime_axes = ng.make_axes(
        [ng.make_axis(length) for length in rsck_axes.lengths])

    # broadcast input / filter axes
    image = ng.constant(np.ones(image_axes.lengths), image_axes)
    filter_var = ng.variable(rsck_prime_axes, initial_value=np.ones((R, S, C, K)))
    filter_op = ng.cast_axes(filter_var, rsck_axes)
    filter_op = ng.expand_dims(filter_op, ax.T, 0)
    filter_op = ng.axes_with_order(filter_op, axes=filter_axes)

    # convolution
    conv_out = ng.convolution(conv_params, image, filter_op, axes=conv_axes)
    out_c, out_d, out_h, out_w, out_n = conv_out.axes
    conv_out = ng.axes_with_order(
        conv_out, axes=ng.make_axes([out_n, out_d, out_h, out_w, out_c]))

    # slice away the oD
    keep_all = slice(None)
    conv = ng.Slice(conv_out, [keep_all, 0, keep_all, keep_all, keep_all])
    flat = ng.flatten_at(conv, idx=1)

    # cost and grad
    cost = ng.sum(flat, reduction_axes=flat.axes)
    grad = ng.deriv(cost, filter_var)

    # compute
    computation = executor([conv, grad])
    conv_val, grad_val = computation()

    expected_conv = np.zeros_like(conv_val) + 75.
    expected_grad = np.zeros_like(grad_val) + 4608.
    assert np.allclose(conv_val, expected_conv)
    assert np.allclose(grad_val, expected_grad)
Exemple #6
0
def reshape_workaround(
        data, shape_out):  # type: (TensorOp, Sequence[int]) -> TensorOp
    """Limited workaround for tensor reshape operation.

    Only reshapes that can be expressed as a flatten are supported: a 1d
    target flattens the whole tensor, a 2d target flattens the leading
    axes whose lengths multiply to `shape_out[0]`. The result is cast to
    positional axes of shape `shape_out`.

    :param data: input tensor op; its shape is read from `data.shape.lengths`
    :param shape_out: requested output shape
    :return: the reshaped tensor op
    :raises ValueError: if input and output hold a different number of
        elements, or if a 2d target cannot be produced by a flatten
    :raises NotImplementedError: if `shape_out` has more than two dimensions
        (or none)
    """
    shape_in = data.shape.lengths

    size_in, size_out = np.prod(shape_in), np.prod(shape_out)
    if size_in != size_out:
        # Format explicitly: exceptions do not interpolate extra arguments.
        raise ValueError(
            'Total size of input (%d) and output (%d) dimension mismatch.'
            % (size_in, size_out))

    ndims_out = len(shape_out)
    if ndims_out == 1:
        tensor = ng.flatten(data)
    elif ndims_out == 2:
        # shape_out[0] must equal the product of the first k input
        # dimensions; the flatten point is then k.
        cumprods = list(np.cumprod(shape_in))
        if shape_out[0] not in cumprods:
            raise ValueError(
                'Reshape from %s to %s cannot be expressed as a flatten.'
                % (tuple(shape_in), tuple(shape_out)))
        flatten_at_idx = cumprods.index(shape_out[0]) + 1
        tensor = ng.flatten_at(data, flatten_at_idx)
    else:
        raise NotImplementedError(
            'Reshape can only support flatten to 1d or 2d.')

    return ng.cast_axes(tensor, make_pos_axes(shape_out))
Exemple #7
0
def test_flatten_deriv():
    """
    Check the derivative of a cost computed through a flatten.

    A placeholder over (N, H, W, C) is flattened at index 1, multiplied by
    random constant weights and summed; ``check_derivative`` is then run on
    the cost with respect to the placeholder.
    """
    from ngraph.frontends.neon import ax
    np.random.seed(0)

    # set shape
    C, D, H, W, N = (3, 1, 28, 28, 8)  # image
    Y = 10

    # Assign lengths to the shared neon axes used below.
    for axis, length in ((ax.C, C), (ax.D, D), (ax.H, H),
                         (ax.W, W), (ax.N, N), (ax.Y, Y)):
        axis.length = length

    # conv output
    conv = ng.placeholder(ng.make_axes([ax.N, ax.H, ax.W, ax.C]))

    # flatten
    flat = ng.flatten_at(conv, idx=1)
    flat_len = flat.axes.lengths[1]
    flat = ng.cast_axes(
        flat, ng.make_axes([ax.N, ng.make_axis(flat_len)]))

    # fc
    weight_axes = ng.make_axes([ng.make_axis(flat_len), ax.Y])
    weights = ng.constant(np.random.randn(flat_len, Y), axes=weight_axes)
    # Cast the flattened axis to the weights' first axis (offset by -1) so
    # the dot below contracts over it.
    flat_for_dot = ng.cast_axes(
        flat, ng.make_axes([flat.axes[0], weight_axes[0] - 1]))
    logits = ng.dot(flat_for_dot, weights)
    cost = ng.sum(logits, reduction_axes=logits.axes)

    delta = 0.001
    conv_value = rng.uniform(.1, 5.0, conv.axes)
    check_derivative(cost, conv, delta, conv_value, atol=1e-2, rtol=1e-2)
Exemple #8
0
def test_conv_flatten_deriv(transformer_factory):
    """
    Test deriv of conv followed by flatten.

    Builds a convolution over explicitly shaped axes, reorders and slices
    its output, flattens it at index 1 and sums it into a scalar cost. The
    forward value and the symbolic derivatives w.r.t. both the filter and
    the input are compared against reference values.
    """

    # set shape
    # NOTE: N must be >= 4 for GPU, but for CPU this could be decreased to
    # speed up the test
    N = 4
    C, D, H, W = (3, 1, 28, 28)
    T, R, S, K = (1, 5, 5, 8)

    conv_params = dict(
        pad_d=0, pad_h=0, pad_w=0,
        str_d=1, str_h=1, str_w=1,
        dil_d=1, dil_h=1, dil_w=1,
    )

    # i, f, o axes
    in_axes = ng.make_axes([ax.C, ax.D, ax.H, ax.W, ax.N])
    filt_axes = ng.make_axes([ax.C, ax.T, ax.R, ax.S, ax.K])
    out_axes = ng.make_axes([
        ng.make_axis(roles=[ar.features_input]).named('C'),
        ng.make_axis(roles=[ar.features_0]).named('D'),
        ng.make_axis(roles=[ar.features_1]).named('H'),
        ng.make_axis(roles=[ar.features_2]).named('W'),
        ax.N,
    ])

    in_axes.set_shape((C, D, H, W, N))
    filt_axes.set_shape((C, T, R, S, K))
    out_axes.set_shape((K, D - T + 1, H - R + 1, W - S + 1, N))
    rsck_axes = ng.make_axes([ax.R, ax.S, ax.C, ax.K])
    # Fresh axes with the same lengths and a 'p' suffix on each name.
    rsck_prime_axes = ng.make_axes([
        ng.make_axis(src.length).named(src.name + 'p') for src in rsck_axes
    ])
    nmpqk_axes = ng.make_axes(
        [out_axes[-1], out_axes[1], out_axes[2], out_axes[3], out_axes[0]])

    # broadcast input / filter axes
    input_var = ng.variable(in_axes).named('input')
    input_var.input = True
    input_ones = np.ones(input_var.axes.lengths)

    filter_var = ng.variable(rsck_prime_axes)
    filt_rsck = ng.cast_axes(filter_var, rsck_axes)
    filt_trsck = ng.expand_dims(filt_rsck, ax.T, 0)
    filt_ctrsk = ng.axes_with_order(filt_trsck, axes=filt_axes)

    # convolution
    conv_kmpqn = ng.convolution(conv_params, input_var, filt_ctrsk, axes=out_axes)
    conv_nmpqk = ng.axes_with_order(conv_kmpqn, axes=nmpqk_axes)

    # slice away the oD
    keep_all = slice(None)
    conv_npqk = ng.tensor_slice(
        conv_nmpqk, [keep_all, 0, keep_all, keep_all, keep_all])

    flat_out = ng.flatten_at(conv_npqk, idx=1)

    # cost and grad
    cost = ng.sum(flat_out, out_axes=())

    filter_var.input = True
    filter_var.named('filter')
    filter_ones = np.ones(filter_var.axes.lengths)

    with ExecutorFactory() as factory:

        forward = factory.executor(flat_out, filter_var, input_var)
        d_filter_numeric = factory.numeric_derivative(
            cost, filter_var, 1.0, input_var)
        d_filter_symbolic = factory.derivative(cost, filter_var, input_var)

        d_input_numeric = factory.numeric_derivative(
            cost, input_var, 1.0, filter_var)
        d_input_symbolic = factory.derivative(cost, input_var, filter_var)

        # With all-ones operands the expected output value is C * T * R * S.
        actual = forward(filter_ones, input_ones)
        expected = np.full_like(actual, C * T * R * S)
        assert ng.testing.allclose(actual, expected)

        numeric = d_filter_numeric(filter_ones, input_ones)
        symbolic = d_filter_symbolic(filter_ones, input_ones)
        assert ng.testing.allclose(numeric, symbolic)

        numeric = d_input_numeric(input_ones, filter_ones)
        symbolic = d_input_symbolic(input_ones, filter_ones)
        assert ng.testing.allclose(numeric, symbolic)