def test_set_module_by_node_name__for_nested_module():
    """Replace a module nested inside a Sequential and look it up again.

    The ``Sequential[layer1]`` child of ``Sequential[layer2]`` is swapped for a
    fresh ``ReLU``. The lookup afterwards uses the node name with the new
    type prefix (``ReLU[layer1]``), as exercised below.
    """
    model = TestModel()
    new_module = ReLU()
    set_module_by_node_name(
        model, 'TestModel/Sequential[layer2]/Sequential[layer1]', new_module)
    found = get_module_by_node_name(
        model, 'TestModel/Sequential[layer2]/ReLU[layer1]')
    # Identity check: the exact instance passed in must be the one installed,
    # not merely an object that happens to compare equal to it.
    assert found is new_module
def test_set_module_by_node_name__for_non_nested_module():
    """Replace a direct child of the model and look it up again.

    The top-level ``BatchNorm2d[bn1]`` module is swapped for a fresh ``ReLU``.
    The lookup afterwards uses the node name with the new type prefix
    (``ReLU[bn1]``), as exercised below.
    """
    model = TestModel()
    new_module = ReLU()
    set_module_by_node_name(model, 'TestModel/BatchNorm2d[bn1]', new_module)
    found = get_module_by_node_name(model, 'TestModel/ReLU[bn1]')
    # Identity check: the exact instance passed in must be the one installed,
    # not merely an object that happens to compare equal to it.
    assert found is new_module