def test_verify_cross_layer_scaling(self):
    """Scaling one conv pair equalizes per-channel weight ranges while
    leaving the network output (numerically) unchanged."""
    torch.manual_seed(10)
    model = MyModel().eval()

    dummy_input = torch.rand(2, 10, 24, 24)
    baseline = model(dummy_input).detach().numpy()

    CrossLayerScaling.scale_cls_set_with_conv_layers((model.conv1, model.conv2))
    scaled = model(dummy_input).detach().numpy()

    # Per-output-channel range of conv1 must match the per-input-channel
    # range of conv2 after equalization.
    range_conv1 = np.amax(np.abs(model.conv1.weight.detach().cpu().numpy()),
                          axis=(1, 2, 3))
    range_conv2 = np.amax(np.abs(model.conv2.weight.detach().cpu().numpy()),
                          axis=(0, 2, 3))
    assert np.allclose(range_conv1, range_conv2)

    # Cross-layer scaling is an invariant transform of the network function.
    assert np.allclose(baseline, scaled, rtol=1.e-2)
def test_verify_cross_layer_for_multiple_pairs(self):
    """scale_cls_sets should rescale the weights of every pair it is given."""
    model = MyModel()
    layer_pairs = [(model.conv1, model.conv2), (model.conv3, model.conv4)]

    # Snapshot the weights we expect to change.
    watched = (model.conv1, model.conv2, model.conv3)
    before = [layer.weight.detach().numpy() for layer in watched]

    CrossLayerScaling.scale_cls_sets(layer_pairs)

    # Every watched layer's weights must have been updated by the scaling.
    for layer, old_weight in zip(watched, before):
        assert not np.allclose(layer.weight.detach().numpy(), old_weight)
def test_verify_cross_layer_scaling_depthwise_separable_layer_mobilnet(
        self):
    # CLS on (conv, depthwise-conv, pointwise-conv) triplets of a
    # MobileNetV1-style mock model: after scaling each triplet, the relevant
    # per-channel weight ranges agree and the model output is preserved.
    torch.manual_seed(10)
    model = MockMobileNetV1()
    model = model.eval()
    model = model.to(torch.device('cpu'))
    # Give every layer in the triplets an explicit bias so that bias
    # rescaling is exercised as well (the mock layers have none by default).
    model.model[0][0].bias = torch.nn.Parameter(
        torch.rand(model.model[0][0].weight.data.size()[0]))
    model.model[1][0].bias = torch.nn.Parameter(
        torch.rand(model.model[1][0].weight.data.size()[0]))
    model.model[1][3].bias = torch.nn.Parameter(
        torch.rand(model.model[1][3].weight.data.size()[0]))
    model.model[2][0].bias = torch.nn.Parameter(
        torch.rand(model.model[2][0].weight.data.size()[0]))
    model.model[2][3].bias = torch.nn.Parameter(
        torch.rand(model.model[2][3].weight.data.size()[0]))

    random_input = torch.rand(1, 3, 224, 224)
    baseline_output = model(random_input).detach().numpy()

    # Note the triplets overlap: model.model[1][3] is the last layer of the
    # first triplet and the first layer of the second.
    consecutive_layer_list = [
        (model.model[0][0], model.model[1][0], model.model[1][3]),
        (model.model[1][3], model.model[2][0], model.model[2][3])
    ]
    for consecutive_layer in consecutive_layer_list:
        CrossLayerScaling.scale_cls_set_with_depthwise_layers(
            consecutive_layer)
        # Ranges are checked right after scaling each triplet (scaling the
        # next, overlapping triplet rescales the shared layer again).
        r1 = np.amax(np.abs(
            consecutive_layer[0].weight.detach().cpu().numpy()),
            axis=(1, 2, 3))
        r2 = np.amax(np.abs(
            consecutive_layer[1].weight.detach().cpu().numpy()),
            axis=(1, 2, 3))
        r3 = np.amax(np.abs(
            consecutive_layer[2].weight.detach().cpu().numpy()),
            axis=(0, 2, 3))
        assert (np.allclose(r1, r2))
        assert (np.allclose(r2, r3))

    # The overall network function must be unchanged by the scaling.
    output_after_scaling = model(random_input).detach().numpy()
    assert (np.allclose(baseline_output, output_after_scaling, rtol=1.e-2))
def test_auto_mobilenetv1(self):
    """Automatic CLS on the mock MobileNetV1 discovers eight scaling sets."""
    torch.manual_seed(10)
    model = MockMobileNetV1().eval()

    input_shape = (1, 3, 224, 224)
    # Fold batch norms first, as required before equalization.
    fold_all_batch_norms(model, input_shape)

    scale_factors = CrossLayerScaling.scale_model(model, input_shape)
    self.assertEqual(8, len(scale_factors))
def test_verify_cross_layer_scaling_depthwise_separable_layer_multiple_triplets(
        self):
    """scale_cls_sets should update weights for every depthwise triplet."""
    torch.manual_seed(10)
    model = MockMobileNetV1().eval()

    triplets = [
        (model.model[0][0], model.model[1][0], model.model[1][3]),
        (model.model[1][3], model.model[2][0], model.model[2][3]),
    ]

    # Snapshot one layer from each distinct position before scaling.
    watched = (model.model[0][0], model.model[1][3], model.model[2][3])
    before = [layer.weight.detach().numpy() for layer in watched]

    CrossLayerScaling.scale_cls_sets(triplets)

    for layer, old_weight in zip(watched, before):
        assert not np.allclose(layer.weight.detach().numpy(), old_weight)
def test_cle_depthwise_transposed_conv2D(self):
    # Full CLE pipeline (BN fold -> cross-layer scaling -> high-bias fold)
    # on a model containing a depthwise ConvTranspose2d: weight shapes must
    # be preserved and the output must match the pre-CLE output.

    class TransposedConvModel(torch.nn.Module):
        # Conv -> BN -> ReLU -> depthwise ConvTranspose -> BN -> ReLU ->
        # ConvTranspose -> BN
        def __init__(self):
            super(TransposedConvModel, self).__init__()
            self.conv = torch.nn.Conv2d(20, 10, 3)
            self.bn = torch.nn.BatchNorm2d(10)
            self.relu = torch.nn.ReLU()
            # groups=10 makes this a depthwise transposed convolution.
            self.conv1 = torch.nn.ConvTranspose2d(10, 10, 3, groups=10)
            self.bn1 = torch.nn.BatchNorm2d(10)
            self.relu1 = torch.nn.ReLU()
            self.conv2 = torch.nn.ConvTranspose2d(10, 15, 3)
            self.bn2 = torch.nn.BatchNorm2d(15)

        def forward(self, x):
            # Regular case - conv followed by bn
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            return x

    torch.manual_seed(10)
    model = TransposedConvModel()

    # Record weight shapes up front; CLE must not alter them.
    w_shape_1 = copy.deepcopy(model.conv1.weight.shape)
    w_shape_2 = copy.deepcopy(model.conv2.weight.shape)
    model = model.eval()

    input_shapes = (1, 20, 3, 4)
    random_input = torch.rand(input_shapes)
    output_before_cle = model(random_input).detach().numpy()

    folded_pairs = batch_norm_fold.fold_all_batch_norms(
        model, input_shapes)
    # Map each conv layer to the BN that was folded into it; bias_fold
    # needs this to recover the BN statistics.
    bn_dict = {}
    for conv_bn in folded_pairs:
        bn_dict[conv_bn[0]] = conv_bn[1]

    cls_set_info_list = CrossLayerScaling.scale_model(model, input_shapes)
    HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    self.assertEqual(w_shape_1, model.conv1.weight.shape)
    self.assertEqual(w_shape_2, model.conv2.weight.shape)

    # CLE is an invariant transform: outputs should be (nearly) unchanged.
    output_after_cle = model(random_input).detach().numpy()
    self.assertTrue(
        np.allclose(output_before_cle, output_after_cle, rtol=1.e-2))
def test_auto_transposed_conv2d_model(self):
    """Automatic CLS on a transposed-conv model preserves the output and
    yields a per-channel scale factor vector of length 10."""
    torch.manual_seed(10)
    model = TransposedConvModel()
    model.eval()

    shape = (10, 10, 4, 4)
    dummy_input = torch.rand(shape)
    before = model(dummy_input).detach().numpy()

    scale_factors = CrossLayerScaling.scale_model(model, shape)
    after = model(dummy_input).detach().numpy()

    self.assertTrue(np.allclose(before, after, rtol=1.e-2))
    # One scale factor per channel shared by the first equalized pair.
    self.assertEqual(
        10, len(scale_factors[0].cls_pair_info_list[0].scale_factor))
def test_auto_custom_model(self):
    """Automatic CLS on MyModel: three scaling sets, with a ReLU detected
    between the first two pairs but not the last."""
    torch.manual_seed(10)
    model = MyModel()
    model.eval()

    input_shape = (2, 10, 24, 24)
    # Fold batch norms first, as required before equalization.
    fold_all_batch_norms(model, input_shape)

    scale_factors = CrossLayerScaling.scale_model(model, input_shape)
    self.assertEqual(3, len(scale_factors))

    relu_flags = [sf.cls_pair_info_list[0].relu_activation_between_layers
                  for sf in scale_factors]
    self.assertTrue(relu_flags[0])
    self.assertTrue(relu_flags[1])
    self.assertFalse(relu_flags[2])
def test_cross_layer_equalization_resnet(self):
    """End-to-end CLE on ResNet-18: BN folding replaces the BN modules,
    cross-layer scaling updates the conv weights, and high-bias folding
    never increases a layer's bias."""
    torch.manual_seed(10)
    model = models.resnet18(pretrained=True)
    model = model.eval()

    folded_pairs = batch_norm_fold.fold_all_batch_norms(
        model, (1, 3, 224, 224))
    # Map each conv to its folded BN; bias_fold needs the BN statistics.
    bn_dict = {}
    for conv_bn in folded_pairs:
        bn_dict[conv_bn[0]] = conv_bn[1]

    # After folding, the BN attribute should no longer be a BatchNorm2d.
    self.assertFalse(isinstance(model.layer2[0].bn1, torch.nn.BatchNorm2d))

    w1 = model.layer1[0].conv1.weight.detach().numpy()
    w2 = model.layer1[0].conv2.weight.detach().numpy()
    w3 = model.layer1[1].conv1.weight.detach().numpy()

    cls_set_info_list = CrossLayerScaling.scale_model(
        model, (1, 3, 224, 224))

    # check if weights are updating
    assert not np.allclose(model.layer1[0].conv1.weight.detach().numpy(),
                           w1)
    assert not np.allclose(model.layer1[0].conv2.weight.detach().numpy(),
                           w2)
    assert not np.allclose(model.layer1[1].conv1.weight.detach().numpy(),
                           w3)

    # Bug fix: snapshot the biases with clone(). The bare `.data` reference
    # used previously aliases the live tensor, so if bias_fold mutates the
    # bias in place the `<=` checks below would compare a tensor against
    # itself and pass vacuously.
    b1 = model.layer1[0].conv1.bias.data.clone()
    b2 = model.layer1[1].conv2.bias.data.clone()

    HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

    # High-bias folding moves bias magnitude into the next layer, so every
    # bias element can only stay the same or decrease.
    self.assertTrue(torch.all(model.layer1[0].conv1.bias.data <= b1))
    self.assertTrue(torch.all(model.layer1[1].conv2.bias.data <= b2))
def test_auto_depthwise_transposed_conv_model(self):
    """Automatic CLS on a Sequential model with a depthwise ConvTranspose2d:
    output is preserved to within 1e-6 and two scaling sets are found, the
    first containing two layer pairs."""
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Conv2d(5, 10, 3),
        torch.nn.ReLU(),
        torch.nn.ConvTranspose2d(10, 10, 3, groups=10),  # depthwise
        torch.nn.ReLU(),
        torch.nn.Conv2d(10, 24, 3),
        torch.nn.ReLU(),
        torch.nn.ConvTranspose2d(24, 32, 3),
    )
    model.eval()

    random_input = torch.rand((1, 5, 32, 32))
    baseline_output = model(random_input).detach().numpy()

    scale_factors = CrossLayerScaling.scale_model(model, (1, 5, 32, 32))
    output_after_scaling = model(random_input).detach().numpy()

    # Bug fix: the previous `max(output_diff)` used the *builtin* max over
    # axis 0 of a 4-D ndarray, which only avoided a ValueError because the
    # batch dimension happens to be 1. np.max reduces over all elements
    # regardless of shape.
    output_diff = np.abs(baseline_output - output_after_scaling)
    self.assertLess(np.max(output_diff), 1e-6)

    self.assertEqual(2, len(scale_factors))
    # The depthwise triplet contributes two pairs to the first scaling set.
    self.assertEqual(2, len(scale_factors[0].cls_pair_info_list))