def test_nested_model(self):
    """Check that ``replace_bn`` converts BatchNorm layers at every nesting depth.

    The model mixes a custom submodule (``Block``), a nested
    ``torch.nn.Sequential`` holding another ``Block`` plus a bare
    ``BatchNorm2d``, and a top-level ``BatchNorm2d``.  Every ``BatchNorm2d``
    is expected to become a ``GroupNorm`` while the ``Conv2d`` layers are
    left as they are.
    """

    class Block(torch.nn.Module):
        def __init__(self):
            super(Block, self).__init__()
            self.conv = torch.nn.Conv2d(1, 2, 3)
            self.bn = torch.nn.BatchNorm2d(1)

    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.layers = torch.nn.Sequential(
                Block(),
                torch.nn.Sequential(
                    Block(),
                    torch.nn.BatchNorm2d(1)
                ),
                torch.nn.BatchNorm2d(1)
            )

    model = Model()
    # NOTE(review): the other GroupNorm-conversion tests pass the result of
    # create_norm() to replace_bn, whereas this one passes create_opts().
    # Confirm which argument replace_bn currently expects.
    opts = TestGroupNormConversion.create_opts()
    models.replace_bn(model, opts)

    # One assert per layer so a failure pinpoints which module was (not) converted.
    assert isinstance(model.layers[0].conv, torch.nn.Conv2d)
    assert isinstance(model.layers[0].bn, torch.nn.GroupNorm)
    assert isinstance(model.layers[1][0].conv, torch.nn.Conv2d)
    assert isinstance(model.layers[1][0].bn, torch.nn.GroupNorm)
    # The bare BatchNorm2d sitting directly inside the nested Sequential.
    assert isinstance(model.layers[1][1], torch.nn.GroupNorm)
    assert isinstance(model.layers[2], torch.nn.GroupNorm)
def test_single_element_model(self):
    """A model whose only child is a BatchNorm2d ends up with a GroupNorm there."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(1)

    net = Model()
    replacement = TestGroupNormConversion.create_norm()
    models.replace_bn(net, replacement)

    converted = net.bn
    assert isinstance(converted, torch.nn.GroupNorm)
def test_sequential_model(self):
    """BatchNorm2d inside a Sequential is converted; the preceding Conv2d is kept."""

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            conv = torch.nn.Conv2d(1, 2, 3)
            bn = torch.nn.BatchNorm2d(1)
            self.layers = torch.nn.Sequential(conv, bn)

    net = Model()
    replacement = TestGroupNormConversion.create_norm()
    models.replace_bn(net, replacement)

    assert isinstance(net.layers[1], torch.nn.GroupNorm)
    assert isinstance(net.layers[0], torch.nn.Conv2d)