Example #1
0
def template(in_order, in_shape, out_order, out_shape):
    """Build a Reshape from ``in_order`` to ``out_order`` and check its output shape.

    Args:
        in_order: Order of the input variable; its ``axes`` fix the input layout.
        in_shape: Per-axis input sizes, indexed by the axes of ``in_order``.
        out_order: Order the Reshape should produce.
        out_shape: Per-axis expected output sizes. It is indexed by the axes of
            ``out_order`` to build the Reshape target, and it is also passed
            to ``assert_shape``.
    """
    # The Reshape target is a plain list laid out in out_order's axis sequence.
    target_dims = [out_shape[axis] for axis in out_order.axes]
    reshape = Reshape(None, in_order=in_order, out_order=out_order, out_shape=target_dims)

    source_dims = [in_shape[axis] for axis in in_order.axes]
    x = Variable(source_dims, in_order)

    # The operator call must yield exactly one output variable.
    (y,) = reshape(x)
    assert_shape(y, out_shape)
Example #2
0
def template(N=2, T=3, vocabulary_size=4, feature_size=5, order_x=OrderNT, order_w=OrderNC):
    """Run Embedding on an index input and a weight input, then check the output shape.

    The index input is created as ``[N, T]`` in OrderNT. The weight input is
    created as ``[feature_size, vocabulary_size]`` in OrderNC. Each is then
    switched via ``change_order`` to the order under test.

    Args:
        N: Size of the N axis of the index input.
        T: Size of the T axis of the index input.
        vocabulary_size: Size of the weight's C axis.
        feature_size: Size of the weight's N axis; also the expected output C size.
        order_x: Order applied to the index input before the operator runs.
        order_w: Order applied to the weight input before the operator runs.
    """
    x = Variable([N, T], OrderNT)
    w = Variable([feature_size, vocabulary_size], OrderNC)

    # Switch both inputs to the layouts being tested.
    for variable, new_order in ((x, order_x), (w, order_w)):
        variable.change_order(new_order)

    embedding = Embedding(None)
    (y,) = embedding(x, w)

    expected_shape = AxisKeyDict([Axis.N, Axis.T, Axis.C], [N, T, feature_size])
    assert_shape(y, expected_shape)
Example #3
0
def template_test_unary_operator(OperatorClass,
                                 operator_kwargs=None,
                                 test1d=True,
                                 test2d=True,
                                 test3d=True,
                                 test4d=True,
                                 axes=None,
                                 orders=None,
                                 shape_dict=None,
                                 expected_dict=None):
    """
    Shape-check a unary operator across many input orders.

    For each order under test, this function:

    * builds a Variable whose sizes come from ``shape_dict``;
    * applies ``OperatorClass(None, **operator_kwargs)`` to it;
    * passes the single output and the expected shape to ``assert_shape``.

    Args:
        OperatorClass: Operator class under test.
        operator_kwargs: Keyword arguments forwarded to the operator constructor.
            ``None`` means no extra arguments.
        test1d: When ``orders`` is generated, include 1D input orders if True.
        test2d: When ``orders`` is generated, include 2D input orders if True.
        test3d: When ``orders`` is generated, include 3D input orders if True.
        test4d: When ``orders`` is generated, include 4D input orders if True.
        axes: Axes used to generate ``orders`` and the default ``shape_dict``.
            Defaults to N, H, W, C, T.
        orders: Explicit input orders to test. If ``None``, every permutation of
            1 to 4 distinct axes drawn from ``axes`` is tested, limited to the
            dimensionalities enabled by the ``test*d`` flags.
        shape_dict: Per-axis input sizes. If ``None``, the i-th axis of ``axes``
            gets size ``i + 5``.
        expected_dict: Expected output shape. If ``None``, the input variable's
            own ``shape_dict`` is used instead.
    """
    kwargs = {} if operator_kwargs is None else operator_kwargs

    if axes is None:
        axes = [Axis.N, Axis.H, Axis.W, Axis.C, Axis.T]

    if orders is None:
        # Dimensionality is the position of the flag (1-based).
        ndim_flags = (test1d, test2d, test3d, test4d)
        orders = [
            Order(picked_axes)
            for ndim, enabled in enumerate(ndim_flags, start=1)
            if enabled
            for picked_axes in permutations(axes, ndim)
        ]

    if shape_dict is None:
        shape_dict = AxisKeyDict()
        # Distinct sizes starting at 5, so each axis has its own length.
        for size, axis in enumerate(axes, start=5):
            shape_dict[axis] = size

    for order in orders:
        x = Variable([shape_dict[axis] for axis in order.axes], order)
        operator = OperatorClass(None, **kwargs)
        (y,) = operator(x)
        # Read x.shape_dict only after the operator call, as before.
        expected = x.shape_dict if expected_dict is None else expected_dict
        assert_shape(y, expected)
Example #4
0
def template(in_order, in_shape, out_order, out_shape):
    """Apply ReinterpretAxis from ``in_order`` to ``out_order`` and check the output shape.

    Args:
        in_order: Order of the input variable.
        in_shape: Per-axis input sizes, indexed by the axes of ``in_order``.
        out_order: Order the axes are reinterpreted as.
        out_shape: Expected output shape, passed to ``assert_shape``.
    """
    reinterpret = ReinterpretAxis(None, in_order=in_order, out_order=out_order)
    input_dims = [in_shape[axis] for axis in in_order.axes]
    # Unpacking enforces that the operator returns exactly one output.
    (y,) = reinterpret(Variable(input_dims, in_order))
    assert_shape(y, out_shape)