예제 #1
0
def test_set_module_by_node_name__for_nested_module():
    """Replace a module nested inside another container and read it back.

    The module addressed as ``Sequential[layer1]`` under
    ``Sequential[layer2]`` is replaced with a fresh ``ReLU``. It is then
    looked up under the path ``TestModel/Sequential[layer2]/ReLU[layer1]``,
    and the lookup must return that very instance.
    """
    model = TestModel()
    new_module = ReLU()
    set_module_by_node_name(model,
                            'TestModel/Sequential[layer2]/Sequential[layer1]',
                            new_module)
    # The read-back path names the new type (ReLU), not the original
    # Sequential, for the same ``layer1`` slot.
    found = get_module_by_node_name(
        model, 'TestModel/Sequential[layer2]/ReLU[layer1]')
    # Identity, not equality: the point is that this exact object was installed.
    assert found is new_module
예제 #2
0
def test_set_module_by_node_name__for_non_nested_module():
    """Replace a direct child of the model and read it back.

    The module addressed as ``TestModel/BatchNorm2d[bn1]`` is replaced with a
    fresh ``ReLU``. It is then looked up under ``TestModel/ReLU[bn1]``, and
    the lookup must return that very instance.
    """
    model = TestModel()
    new_module = ReLU()
    set_module_by_node_name(model, 'TestModel/BatchNorm2d[bn1]', new_module)
    # The read-back path names the new type (ReLU) for the same ``bn1`` slot.
    found = get_module_by_node_name(model, 'TestModel/ReLU[bn1]')
    # Identity, not equality: the point is that this exact object was installed.
    assert found is new_module