Example #1
0
import chainer

import chainerx

from chainerx_tests import array_utils
from chainerx_tests import dtype_utils
from chainerx_tests import op_utils

# Valid dtype combinations for the n-step LSTM tests: each float dtype as the
# single input dtype, paired with an empty output-dtype tuple.
n_step_lstm_dtypes_valid = dtype_utils._permutate_dtype_mapping(
    [((dtype,), ()) for dtype in ('float16', 'float32', 'float64')])


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*chainer.testing.product([
    # Network sizes. ``batches`` is non-increasing in every case; presumably
    # it holds the batch size at each sequence step -- confirm in the test.
    chainer.testing.from_pytest_parameterize(
        'n_layers,hidden_size,input_size,batches',
        [
            (2, 2, 1, (1, 1, 1)),
            (2, 2, 3, (3, 2, 1)),
            (3, 8, 4, (4, 2, 1)),
            (4, 12, 4, (4, 3, 2)),
        ]),
    # Crossed with every valid (input dtypes, output dtype) combination.
    chainer.testing.from_pytest_parameterize(
        'in_dtypes, out_dtype', n_step_lstm_dtypes_valid),
]))
class TestNStepLstm(op_utils.ChainerOpTest):
    """Parameterized n-step LSTM op test run on 'native:0' and 'cuda:0'."""

    def setup(self):
        # Loose forward tolerances, applied to every parameter combination.
        forward_tolerance = {'rtol': 1e-2, 'atol': 1e-2}
        self.check_forward_options.update(forward_tolerance)
Example #2
0
from chainerx_tests import dtype_utils
from chainerx_tests import op_utils


# Shapes used to parameterize the loss tests (their use is not visible here).
_loss_shapes = [(2, 2), (3, 3, 3), (5, 5, 5), (4, 1, 2, 4)]


# (input dtypes, output dtype) combinations for the loss tests.
_in_out_loss_dtypes = dtype_utils._permutate_dtype_mapping(
    # Same dtypes: the output keeps the shared float dtype.
    [((dtype, dtype), dtype) for dtype in ('float16', 'float32', 'float64')]
    # Mixed dtypes: the wider float is listed first and is also the output.
    + [((wide, narrow), wide)
       for rank, wide in enumerate(('float16', 'float32', 'float64'))
       for narrow in ('float16', 'float32', 'float64')[:rank]])


class LossBase(op_utils.ChainerOpTest):
    """Common setup for loss op tests that take exactly two input dtypes."""

    def setup(self):
        super().setup()
        # Unpacking (rather than indexing) keeps the check that exactly two
        # input dtypes were supplied.
        first_dtype, second_dtype = self.in_dtypes
        if 'float16' not in (first_dtype, second_dtype):
            return
        # Any half-precision input loosens all three tolerance sets.
        self.check_forward_options.update({'rtol': 5e-3, 'atol': 5e-3})
        self.check_backward_options.update({'rtol': 1e-2, 'atol': 5e-3})
        self.check_double_backward_options.update(
            {'rtol': 1e-2, 'atol': 3e-1})
Example #3
0
# Expected output dtype for each mixed-dtype input pair of the binary math
# functions: two distinct integers give float32, an integer with a float gives
# that float, and two distinct floats give the wider float.
in_out_dtypes_math_binary_functions = dtype_utils._permutate_dtype_mapping(
    # integer mixed: each unordered pair of distinct integer dtypes
    [((int_a, int_b), 'float32')
     for pos, int_a in enumerate(('int8', 'int16', 'int32', 'int64', 'uint8'))
     for int_b in ('int8', 'int16', 'int32', 'int64', 'uint8')[pos + 1:]]
    # integer float mixed: every integer dtype against every float dtype
    + [((int_dtype, float_dtype), float_dtype)
       for int_dtype in ('int8', 'int16', 'int32', 'int64', 'uint8')
       for float_dtype in ('float16', 'float32', 'float64')]
    # float mixed: the later (wider) float of each pair is the output
    + [((float_a, float_b), float_b)
       for pos, float_a in enumerate(('float16', 'float32', 'float64'))
       for float_b in ('float16', 'float32', 'float64')[pos + 1:]])
Example #4
0
from chainerx_tests import op_utils


# (input dtypes, output dtype) combinations for the bitwise op tests.
_in_out_dtypes_bitwise = dtype_utils._permutate_dtype_mapping(
    # Same dtypes: the output keeps the shared input dtype.
    [((dtype, dtype), dtype)
     for dtype in ('bool_', 'int8', 'int16', 'int32', 'int64', 'uint8')]
    # Mixed with bool_: the output is the integer dtype.
    + [(('bool_', dtype), dtype)
       for dtype in ('int8', 'int16', 'int32', 'int64', 'uint8')]
    # Mixed integer dtypes.
    + [
        (('int8', 'int16'), 'int16'),
        (('int8', 'int32'), 'int32'),
        (('int8', 'int64'), 'int64'),
        # Neither 8-bit type can hold the other's full range, so the
        # expected result widens to int16.
        (('int8', 'uint8'), 'int16'),
        (('int16', 'int32'), 'int32'),
        (('int16', 'int64'), 'int64'),
        (('int16', 'uint8'), 'int16'),
        (('int32', 'int64'), 'int64'),
        (('int32', 'uint8'), 'int32'),
        (('int64', 'uint8'), 'int64'),
    ])

_in_out_dtypes_inplace_bitwise_invalid = [
    (('bool_', 'int8'), 'int8'),
Example #5
0
import chainerx

from chainerx_tests import array_utils
from chainerx_tests import dtype_utils
from chainerx_tests import op_utils


# Sentinel marking an argument that was left unspecified. A dedicated class
# (rather than None) presumably lets None remain a valid explicit value --
# confirm against the call sites.
class Unspecified:
    """Placeholder type representing an unspecified argument."""


# Valid dtype combinations for the LSTM tests: both entries of the first tuple
# and both entries of the second tuple share one float dtype (presumably the
# input and output dtypes respectively).
lstm_dtypes_valid = dtype_utils._permutate_dtype_mapping(
    [((dtype, dtype), (dtype, dtype))
     for dtype in ('float16', 'float32', 'float64')])

lstm_dtypes_invalid = dtype_utils._permutate_dtype_mapping([
    # Bools.
    (('bool_', 'bool_'), ('bool_', 'bool_')),
    # Floats.
    (('float32', 'float16'), ('float32', 'float16')),
    (('float64', 'float16'), ('float64', 'float16')),
    (('float64', 'float32'), ('float64', 'float32')),
    # Signed ints.
    (('int8', 'int8'), ('int8', 'int8')),
    (('int8', 'int16'), ('int8', 'int16')),
    (('int8', 'int32'), ('int8', 'int32')),
    (('int8', 'int64'), ('int8', 'int64')),