def test_set_module_by_node_name__for_nested_module():
    """Swap the module at layer2 -> layer1 and find it again under its new class name.

    The node name encodes the module class, so after replacing the nested
    Sequential with a ReLU the same attribute path is addressed as
    ``ReLU[layer1]``.
    """
    model = TestModel()
    replacement = ReLU()
    old_path = 'TestModel/Sequential[layer2]/Sequential[layer1]'
    new_path = 'TestModel/Sequential[layer2]/ReLU[layer1]'

    set_module_by_node_name(model, old_path, replacement)

    assert replacement == get_module_by_node_name(model, new_path)
def test_apply_by_node_name():
    """Check that ``apply_by_node_name`` applies ``command`` to the named module.

    The BatchNorm weight is first filled with ones. It is then zeroed through
    ``apply_by_node_name``, and the test verifies that the change is visible
    on the module object obtained by node name.
    """
    model = TestModel()
    node_name = 'TestModel/BatchNorm2d[bn1]'
    bn1 = get_module_by_node_name(model, node_name)

    bn1.weight.data.fill_(1)
    # A bare ``assert tensor == 1`` only works for single-element tensors.
    # With more than one channel it raises "Boolean value of Tensor with more
    # than one value is ambiguous", so reduce explicitly with ``.all()``.
    assert (bn1.weight == 1).all()

    apply_by_node_name(model, [node_name], command=lambda m: init.zeros_(m.weight))
    assert (bn1.weight == 0).all()
def test_get_module_by_node_name__for_nested_module():
    """A nested node name resolves to the child module along its full path."""
    model = TestModel()
    found = get_module_by_node_name(
        model, 'TestModel/Sequential[layer2]/Sequential[layer1]')
    # The node path goes through ``layer2``, so the expected module is
    # ``model.layer2.layer1``, not the top-level ``model.layer1``.
    assert found is model.layer2.layer1
def test_get_module_by_node_name__for_non_nested_module__test_model():
    """A top-level TestModel node name resolves to the direct child module.

    This is named distinctly from the ModelForTest variant of the same check.
    When two module-level functions share a name, the later one rebinds it and
    pytest only collects that one.
    """
    model = TestModel()
    assert get_module_by_node_name(model, 'TestModel/BatchNorm2d[bn1]') is model.bn1
def test_set_module_by_node_name__for_non_nested_module():
    """Replace the top-level ``bn1`` child and look it up under its new class name."""
    model = TestModel()
    replacement = ReLU()

    set_module_by_node_name(model, 'TestModel/BatchNorm2d[bn1]', replacement)
    looked_up = get_module_by_node_name(model, 'TestModel/ReLU[bn1]')

    assert replacement == looked_up
def test_get_module_by_node_name__for_non_nested_module():
    """Look up ModelForTest's top-level ``bn1`` by its node name."""
    model = ModelForTest()
    actual = get_module_by_node_name(model, 'ModelForTest/BatchNorm2d[bn1]')
    expected = model.bn1
    assert actual == expected