예제 #1
0
def test_set_module_by_node_name__for_nested_module():
    """Swapping a nested submodule by node name installs the given module.

    ``layer1`` inside ``layer2`` is replaced with a fresh ``ReLU``. Looking it
    up again under the path that reflects its new type (``ReLU[layer1]``)
    must return that very object.
    """
    model = TestModel()
    replacement = ReLU()
    target_name = 'TestModel/Sequential[layer2]/Sequential[layer1]'
    lookup_name = 'TestModel/Sequential[layer2]/ReLU[layer1]'

    set_module_by_node_name(model, target_name, replacement)

    # Identity, not just equality: the exact instance must have been installed.
    assert get_module_by_node_name(model, lookup_name) is replacement
예제 #2
0
def test_apply_by_node_name():
    """``apply_by_node_name`` runs ``command`` on the module at each node name.

    The ``bn1`` weight is first set to all ones, then zeroed through
    ``apply_by_node_name``. The change must be visible through the module
    reference obtained before the call.
    """
    model = TestModel()
    node_name = 'TestModel/BatchNorm2d[bn1]'
    bn1 = get_module_by_node_name(model, node_name)

    # init.ones_ fills the parameter in place under no_grad, mirroring the
    # init.zeros_ command below, instead of mutating it through `.data`.
    init.ones_(bn1.weight)
    # Reduce with .all(): asserting a bare `tensor == 1` raises "Boolean value
    # of Tensor with more than one value is ambiguous" whenever bn1 has more
    # than one channel.
    assert (bn1.weight == 1).all()

    apply_by_node_name(model, [node_name], command=lambda m: init.zeros_(m.weight))

    assert (bn1.weight == 0).all()
예제 #3
0
def test_get_module_by_node_name__for_nested_module():
    """Resolving the nested ``layer2/layer1`` path yields ``model.layer1``."""
    model = TestModel()
    nested_name = 'TestModel/Sequential[layer2]/Sequential[layer1]'

    found = get_module_by_node_name(model, nested_name)

    # NOTE(review): this presumes TestModel registers the same object as both
    # model.layer1 and layer2's child named "layer1" — confirm against TestModel.
    assert found is model.layer1
예제 #4
0
def test_get_module_by_node_name__for_non_nested_module():
    """A top-level node name resolves to the matching attribute of the model."""
    model = TestModel()
    bn1_name = 'TestModel/BatchNorm2d[bn1]'

    found = get_module_by_node_name(model, bn1_name)

    assert found is model.bn1
예제 #5
0
def test_set_module_by_node_name__for_non_nested_module():
    """Swapping a top-level module by node name installs the given module.

    ``bn1`` is replaced with a fresh ``ReLU``. The lookup path then names the
    new type (``ReLU[bn1]``) and must return that exact instance.
    """
    model = TestModel()
    replacement = ReLU()

    set_module_by_node_name(model, 'TestModel/BatchNorm2d[bn1]', replacement)
    installed = get_module_by_node_name(model, 'TestModel/ReLU[bn1]')

    assert installed is replacement
def test_get_module_by_node_name__for_non_nested_module():
    """A top-level ``ModelForTest`` node name resolves to ``model.bn1``."""
    # NOTE(review): this reuses the name of the TestModel-based test defined
    # earlier in this file. The later ``def`` rebinds the name, so pytest
    # collects only this version; consider renaming one of the two.
    model = ModelForTest()
    bn1_name = 'ModelForTest/BatchNorm2d[bn1]'
    assert get_module_by_node_name(model, bn1_name) is model.bn1