예제 #1
0
# )
#
# The meaning of each field of `TorchNNModuleMetadata` is as follows:
#
# `cpp_default_constructor_args`: string that represents the required non-keyword arguments
#     for the C++ module constructor. For example, since `LinearOptions` expects two non-keyword
#     arguments `(in_features, out_features)`, the `cpp_default_constructor_args` for `Linear`
#     will be the string representation of any integer 2-tuple, such as "(3, 4)".
#     Note that the C++ module constructor must take the exact same number of non-keyword arguments
#     as the Python module constructor.
#
# `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor
#     attributes) of the Python module. If the module contains any submodule, the submodule's
#     attributes also need to be counted.
module_metadata_map = {
    'Conv1d': TorchNNModuleMetadata(),
    'Conv2d': TorchNNModuleMetadata(),
    'Conv3d': TorchNNModuleMetadata(),
    'ConvTranspose1d': TorchNNModuleMetadata(),
    'ConvTranspose2d': TorchNNModuleMetadata(),
    'ConvTranspose3d': TorchNNModuleMetadata(),
    'Unfold': TorchNNModuleMetadata(),
    'Fold': TorchNNModuleMetadata(),
    'MaxPool1d': TorchNNModuleMetadata(),
    'MaxPool2d': TorchNNModuleMetadata(),
    'MaxPool3d': TorchNNModuleMetadata(),
    'MaxUnpool1d': TorchNNModuleMetadata(),
    'MaxUnpool2d': TorchNNModuleMetadata(),
    'MaxUnpool3d': TorchNNModuleMetadata(),
    'AvgPool1d': TorchNNModuleMetadata(),
    'AvgPool2d': TorchNNModuleMetadata(),
예제 #2
0
}
"""

# Test-case descriptions for `SampleModule`. Each entry is a plain dict of settings;
# presumably consumed by the Python/C++ API parity test harness — confirm against the caller.
module_tests = [
    # Case described by `module_name` + `constructor_args`, marked `has_parity=True`.
    {
        'module_name': 'SampleModule',
        'desc': 'has_parity',
        'constructor_args': (True, True),
        'cpp_constructor_args': '(true)',
        'input_size': (3, 4),
        'has_parity': True,
    },
    # Case described by `fullname` + a zero-argument `constructor` callable,
    # marked `has_parity=False`. The callable is only stored here, not invoked.
    {
        'fullname': 'SampleModule_no_parity',
        'constructor': lambda: SampleModule(False, True),
        'cpp_constructor_args': '(true)',
        'input_size': (3, 4),
        'has_parity': False,
    },
]

# Register the metadata entry for `SampleModule` under its name in
# `torch_nn_modules.module_metadata_map`.
torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata(
    # String form of the required non-keyword arguments of the C++ module constructor.
    cpp_default_constructor_args='(true)',
    # Expected attribute count of the Python module (submodule attributes included).
    # NOTE(review): the value must match the SampleModule definition, which is not shown here.
    num_attrs_recursive=20,
    # C++ source for the sample module; presumably compiled by the test harness — confirm.
    cpp_sources=SAMPLE_MODULE_CPP_SOURCE,
    # `has_parity` is listed as ignored both as a constructor argument and as an attribute.
    python_ignored_constructor_args=['has_parity'],
    python_ignored_attrs=['has_parity'],
)

# Expose `SampleModule` as an attribute of `torch.nn`
# (presumably so it can be looked up by name like the built-in modules — confirm).
torch.nn.SampleModule = SampleModule
예제 #3
0
#
# The meaning of each field of `TorchNNModuleMetadata` is as follows:
#
# `cpp_default_constructor_args`: string that represents the required non-keyword arguments
#     for the C++ module constructor. For example, since `LinearOptions` expects two non-keyword
#     arguments `(in_features, out_features)`, the `cpp_default_constructor_args` for `Linear`
#     will be the string representation of any integer 2-tuple, such as "(3, 4)".
#     Note that the C++ module constructor must take the exact same number of non-keyword arguments
#     as the Python module constructor.
#
# `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor
#     attributes) of the Python module. If the module contains any submodule, the submodule's
#     attributes also need to be counted.
module_metadata_map = {
    'Conv1d':
    TorchNNModuleMetadata(),
    'Conv2d':
    TorchNNModuleMetadata(),
    'Conv3d':
    TorchNNModuleMetadata(),
    'ConvTranspose1d':
    TorchNNModuleMetadata(),
    'ConvTranspose2d':
    TorchNNModuleMetadata(),
    'ConvTranspose3d':
    TorchNNModuleMetadata(),
    'Unfold':
    TorchNNModuleMetadata(),
    'Fold':
    TorchNNModuleMetadata(
        cpp_default_constructor_args="(3, 2)",
예제 #4
0
# `cpp_default_constructor_args`: string that represents the required non-keyword arguments
#     for the C++ module constructor. For example, since `LinearOptions` expects two non-keyword
#     arguments `(in_features, out_features)`, the `cpp_default_constructor_args` for `Linear`
#     will be the string representation of any integer 2-tuple, such as "(3, 4)".
#     Note that the C++ module constructor must take the exact same number of non-keyword arguments
#     as the Python module constructor.
#
# `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor
#     attributes) of the Python module. If the module contains any submodule, the submodule's
#     attributes also need to be counted.
#
# `python_legacy_constructor_args`: (optional) list of legacy Python constructor args that are
#     ignored in Python/C++ API parity test.
module_metadata_map = {
    'Conv1d':
    TorchNNModuleMetadata(),
    'Conv2d':
    TorchNNModuleMetadata(),
    'Conv3d':
    TorchNNModuleMetadata(),
    'ConvTranspose1d':
    TorchNNModuleMetadata(),
    'ConvTranspose2d':
    TorchNNModuleMetadata(),
    'ConvTranspose3d':
    TorchNNModuleMetadata(),
    'Unfold':
    TorchNNModuleMetadata(),
    'Fold':
    TorchNNModuleMetadata(
        cpp_default_constructor_args="(3, 2)",
예제 #5
0
#
# `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor
#     attributes, but excluding the attributes in `python_ignored_attrs`) of the Python module.
#     If the module contains any submodule, the submodule's attributes also need to be counted.
#
# `python_ignored_constructor_args`: (optional) list of Python constructor args that are
#     ignored in Python/C++ API parity test.
#
# `python_ignored_attrs`: (optional) list of Python module attributes (including parameters,
#     buffers and non-tensor attributes) that are ignored in Python/C++ API parity test.
#
# `python_optional_attribute_to_jit_type`: (optional) map from each Python None-able module
#     attribute to its corresponding JIT type. For example, in `AvgPool2d`:
#     { "divisor_override": torch._C.OptionalType(torch._C.IntType.get()) }
module_metadata_map = {
    'Conv1d': TorchNNModuleMetadata(),
    'Conv2d': TorchNNModuleMetadata(),
    'Conv3d': TorchNNModuleMetadata(),
    'ConvTranspose1d': TorchNNModuleMetadata(),
    'ConvTranspose2d': TorchNNModuleMetadata(),
    'ConvTranspose3d': TorchNNModuleMetadata(),
    'Unfold': TorchNNModuleMetadata(),
    'Fold': TorchNNModuleMetadata(
        cpp_default_constructor_args="(3, 2)",
        num_attrs_recursive=5,
    ),
    'MaxPool1d': TorchNNModuleMetadata(
        cpp_default_constructor_args="(2)",
        num_attrs_recursive=5,
        python_ignored_constructor_args=['return_indices'],
        python_ignored_attrs=['return_indices'],
예제 #6
0
        desc='has_parity',
        constructor_args=(True, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        has_parity=True,
    ),
    dict(
        fullname='SampleModule_no_parity',
        constructor=lambda: SampleModule(False, True),
        cpp_constructor_args='(true)',
        input_size=(3, 4),
        has_parity=False,
    ),
]

# Register the metadata entry for `SampleModule` under its name in
# `torch_nn_modules.module_metadata_map`.
torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata(
    # String form of the required non-keyword arguments of the C++ module constructor.
    cpp_default_constructor_args='(true)',
    # Expected attribute count of the Python module (submodule attributes included).
    # NOTE(review): the value must match the SampleModule definition, which is not shown here.
    num_attrs_recursive=18,
    # Names of option fields; presumably the members of the C++ options struct
    # that the parity test compares — confirm against TorchNNModuleMetadata.
    options_args=[
        'has_submodule',
        'int_option',
        'double_option',
        'bool_option',
        'string_option',
        'tensor_option',
    ],
    # C++ source for the sample module; presumably compiled by the test harness — confirm.
    cpp_sources=SAMPLE_MODULE_CPP_SOURCE,
)

# Expose `SampleModule` as an attribute of `torch.nn`
# (presumably so it can be looked up by name like the built-in modules — confirm).
torch.nn.SampleModule = SampleModule
예제 #7
0
}
}
"""

# Test-case descriptions for `SampleModule`. Each entry is a plain dict of settings;
# presumably consumed by the Python/C++ API parity test harness — confirm against the caller.
module_tests = [
    # Constructed with (True, True); marked `has_parity=True`.
    {
        'module_name': 'SampleModule',
        'constructor_args': (True, True),
        'cpp_constructor_args': '(true)',
        'input_size': (3, 4),
        'desc': 'has_parity',
        'has_parity': True,
    },
    # Constructed with (False, True); marked `has_parity=False`.
    {
        'module_name': 'SampleModule',
        'constructor_args': (False, True),
        'cpp_constructor_args': '(true)',
        'input_size': (3, 4),
        'desc': 'no_parity',
        'has_parity': False,
    },
]

# Register the metadata entry for `SampleModule` under its name in
# `torch_nn_modules.module_metadata_map`.
torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata(
    # String form of the required non-keyword arguments of the C++ module constructor.
    cpp_default_constructor_args='(true)',
    # Expected attribute count of the Python module (submodule attributes included).
    # NOTE(review): the value must match the SampleModule definition, which is not shown here.
    num_attrs_recursive=6,
    # C++ source for the sample module; presumably compiled by the test harness — confirm.
    cpp_sources=SAMPLE_MODULE_CPP_SOURCE,
)

# Expose `SampleModule` as an attribute of `torch.nn`
# (presumably so it can be looked up by name like the built-in modules — confirm).
torch.nn.SampleModule = SampleModule
예제 #8
0
#     Note that the C++ module constructor must take the exact same number of non-keyword arguments
#     as the Python module constructor.
#
# `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor
#     attributes) of the Python module. If the module contains any submodule, the submodule's
#     attributes also need to be counted.
#
# `python_legacy_constructor_args`: (optional) list of legacy Python constructor args that are
#     ignored in Python/C++ API parity test.
#
# `python_optional_attribute_to_jit_type`: (optional) map from each Python None-able module
#     attribute to its corresponding JIT type. For example, in `AvgPool2d`:
#     { "divisor_override": torch._C.OptionalType(torch._C.IntType.get()) }
module_metadata_map = {
    'Conv1d':
    TorchNNModuleMetadata(),
    'Conv2d':
    TorchNNModuleMetadata(),
    'Conv3d':
    TorchNNModuleMetadata(),
    'ConvTranspose1d':
    TorchNNModuleMetadata(),
    'ConvTranspose2d':
    TorchNNModuleMetadata(),
    'ConvTranspose3d':
    TorchNNModuleMetadata(),
    'Unfold':
    TorchNNModuleMetadata(),
    'Fold':
    TorchNNModuleMetadata(
        cpp_default_constructor_args="(3, 2)",