def test_checkpoint_backward_compatibility(self):
    """Historical checkpoints must still load into the current test model.

    Each checkpoint file is loaded through ``_create_test_model`` and compared
    with the model that ``_create_test_model()`` builds without a checkpoint.
    """
    reference = _create_test_model()
    input_shape = (3, 32, 32)

    # checkpoint_d1e8cad.pt predates the resnet naming changes (commit
    # d1e8cad); checkpoint_91ee855.pt comes after them (commit 91ee855).
    historical_files = ("checkpoint_d1e8cad.pt", "checkpoint_91ee855.pt")
    for filename in historical_files:
        restored = _create_test_model(
            checkpoint_file=CHECKPOINTS_DIR / filename)
        self.assertTrue(compare_models(reference, restored, input_shape))
def test_creaate_model_from_checkpoint(self):
    """A resnet50 rebuilt from a pickled checkpoint file matches its source.

    The source model's weights are set to a constant, and only its serialized
    state dict is saved (optimizer and other experiment state are left out,
    see ImagenetExperiment.get_state). A second model created from that file
    must compare equal on a (3, 32, 32) input.
    """
    def fill_weights(module):
        # Simulate an imagenet experiment by changing the weights.
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.fill_(0.042)

    source = create_model(model_class=resnet50, model_args={},
                          init_batch_norm=False, device="cpu")
    source.apply(fill_weights)

    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, source.state_dict())
        checkpoint_state = {"model": buffer.getvalue()}

    with tempfile.NamedTemporaryFile() as checkpoint_file:
        # Ray saves checkpoints as pickled dicts.
        pickle.dump(checkpoint_state, checkpoint_file)
        checkpoint_file.file.flush()
        restored = create_model(model_class=resnet50, model_args={},
                                init_batch_norm=False, device="cpu",
                                checkpoint_file=checkpoint_file.name)

    self.assertTrue(compare_models(source, restored, (3, 32, 32)))
def test_almost_identical(self):
    """Two copies of a network that differ in a single weight are not equal."""
    first = simple_linear_net()
    second = copy.deepcopy(first)

    # Give the same weight entry opposite values in the two copies.
    for net, value in ((first, 1.0), (second, -1.0)):
        net._modules["0"].weight.data[0, 0] = value

    self.assertFalse(compare_models(first, second, (32,)))
def test_identical(self):
    """A resnet50 reloaded from a pickled checkpoint file equals the original.

    The model is built with ``config=dict(num_classes=3,
    defaults_sparse=True)``. Its state dict is serialized into a temporary
    file and fed back through ``create_model(checkpoint_file=...)``.
    """
    model_class = nupic.research.frameworks.pytorch.models.resnets.resnet50
    model_args = dict(config=dict(num_classes=3, defaults_sparse=True))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def build(**extra):
        # Both models share the same construction arguments.
        return create_model(model_class=model_class, model_args=model_args,
                            init_batch_norm=False, device=device, **extra)

    original = build()

    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, original.state_dict())
        checkpoint_state = {"model": buffer.getvalue()}

    with tempfile.NamedTemporaryFile(delete=True) as checkpoint_file:
        pickle.dump(checkpoint_state, checkpoint_file)
        checkpoint_file.flush()
        restored = build(checkpoint_file=checkpoint_file.name)

    self.assertTrue(compare_models(original, restored, (3, 224, 224)))
def test_fuse_model_conv_bn(model_class):
    """Fusing Conv2d+BatchNorm2d must merge every conv and preserve outputs.

    ``fuse_model(fuse_relu=False)`` is applied to a deep copy of the test
    model. Afterwards:

    * every original BatchNorm2d name must resolve to ``nn.Identity``;
    * the set of ``nni.ConvBn2d`` names must equal the original Conv2d names;
    * ``compare_models`` must report the two models equal on a
      (3, 224, 224) input.
    """
    original = _create_test_model(model_class=model_class)
    conv_layers = {
        name for name, module in original.named_modules()
        if isinstance(module, nn.Conv2d)
    }
    bn_layers = {
        name for name, module in original.named_modules()
        if isinstance(module, nn.BatchNorm2d)
    }
    # With no BN layers the Identity check below would be vacuous.
    assert bn_layers

    # Fuse conv and bn only.
    fused = copy.deepcopy(original)
    fused.fuse_model(fuse_relu=False)
    fused_modules = dict(fused.named_modules())

    # Check that every BN layer was replaced by Identity. Look each original
    # name up explicitly. Filtering fused.named_modules() by name would
    # silently skip a BN layer that was dropped or renamed.
    assert all(
        isinstance(fused_modules.get(name), nn.Identity)
        for name in bn_layers
    )

    # Check that all Conv/BN pairs were merged.
    conv_bn_layers = {
        name for name, module in fused_modules.items()
        if isinstance(module, nni.ConvBn2d)
    }
    assert conv_layers == conv_bn_layers

    # Validate output.
    assert compare_models(original, fused, (3, 224, 224))
def test_simple_cnn(self):
    """A randomly trained SimpleCNN behaves the same once batchnorm is removed."""
    original = SimpleCNN()
    train_randomly(original)
    stripped = remove_batchnorm(original)

    # The result must have fewer direct submodules than the input model.
    self.assertLess(len(stripped._modules), len(original._modules))
    self.assertTrue(compare_models(original, stripped, (1, 32, 32)))
def test_gsc(self):
    """The pretrained GSC sparse CNN behaves the same once batchnorm is removed."""
    pretrained = gsc_sparse_cnn(pretrained=True)
    stripped = remove_batchnorm(pretrained)

    # The result must have fewer direct submodules than the input model.
    self.assertLess(len(stripped._modules), len(pretrained._modules))
    self.assertTrue(compare_models(pretrained, stripped, (1, 32, 32)))
def test_cnn_more_out_channels(self):
    """A wider SimpleCNN behaves the same once batchnorm is removed.

    Uses 16 conv output channels and 20 linear units.
    """
    wide_cnn = SimpleCNN(cnn_out_channels=16, linear_units=20)
    train_randomly(wide_cnn)
    stripped = remove_batchnorm(wide_cnn)

    # The result must have fewer direct submodules than the input model.
    self.assertLess(len(stripped._modules), len(wide_cnn._modules))
    self.assertTrue(compare_models(wide_cnn, stripped, (1, 32, 32)))
def test_simple_cnn(self):
    """Removing batchnorm drops exactly the batchnorm children, nothing else.

    The remaining model must also compare equal to the original on a
    (1, 32, 32) input.
    """
    original = create_simple_cnn()
    train_randomly(original)
    stripped = remove_batchnorm(original)

    # Children of the original that are not batchnorm layers.
    kept = {
        child_name for child_name, child in original.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining, kept)
    self.assertTrue(compare_models(original, stripped, (1, 32, 32)))
def test_gsc(self):
    """Removing batchnorm from the pretrained GSC net keeps its other children.

    The remaining model must also compare equal to the original on a
    (1, 32, 32) input.
    """
    pretrained = gsc_sparse_cnn(pretrained=True)
    stripped = remove_batchnorm(pretrained)

    # Children of the original that are not batchnorm layers.
    kept = {
        child_name for child_name, child in pretrained.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining, kept)
    self.assertTrue(compare_models(pretrained, stripped, (1, 32, 32)))
def test_cnn_sparse_weights(self):
    """A 3-channel sparse-weight SimpleCNN behaves the same without batchnorm.

    After random training the weights are re-zeroed with ``rezero_weights``
    before batchnorm is removed.
    """
    sparse_cnn = SimpleCNN(
        in_channels=3,
        cnn_out_channels=4,
        linear_units=5,
        sparse_weights=True,
    )
    train_randomly(sparse_cnn, in_channels=3)
    sparse_cnn.apply(rezero_weights)
    stripped = remove_batchnorm(sparse_cnn)

    # The result must have fewer direct submodules than the input model.
    self.assertLess(len(stripped._modules), len(sparse_cnn._modules))
    self.assertTrue(compare_models(sparse_cnn, stripped, (3, 32, 32)))
def test_serialization(self):
    """A state dict survives a serialize/deserialize round trip.

    The target network is first given constant weights. It must then match
    the source network once the deserialized state dict is loaded into it.
    """
    source = simple_linear_net()
    target = simple_linear_net()

    def set_constant_weight(module):
        # Make target differ from source before the state is loaded.
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.fill_(42.0)

    target.apply(set_constant_weight)

    with io.BytesIO() as stream:
        serialize_state_dict(stream, source.state_dict())
        stream.seek(0)
        restored_state = deserialize_state_dict(stream)

    target.load_state_dict(restored_state)
    self.assertTrue(compare_models(source, target, (32,)))
def test_cnn_sparse_weights(self):
    """Removing batchnorm from a 3-channel sparse CNN keeps its other children.

    The remaining model must also compare equal to the original on a
    (3, 32, 32) input.
    """
    sparse_cnn = create_simple_cnn(
        in_channels=3,
        cnn_out_channels=4,
        linear_units=5,
        sparse_weights=True,
    )
    train_randomly(sparse_cnn, in_channels=3)
    stripped = remove_batchnorm(sparse_cnn)

    # Children of the original that are not batchnorm layers.
    kept = {
        child_name for child_name, child in sparse_cnn.named_children()
        if not isinstance(child, BATCH_NORM_CLASSES)
    }
    remaining = {child_name for child_name, _ in stripped.named_children()}

    self.assertEqual(remaining, kept)
    self.assertTrue(compare_models(sparse_cnn, stripped, (3, 32, 32)))
def test_fuse_model_conv_bn_relu(model_class):
    """Fusing Conv2d+BatchNorm2d(+ReLU) must merge every conv and keep outputs.

    ``fuse_model(fuse_relu=True)`` is applied to a deep copy of the test
    model. Afterwards:

    * every original BatchNorm2d name must resolve to ``nn.Identity``;
    * every original ReLU name not containing "post_activation" must resolve
      to ``nn.Identity``;
    * the set of ``nni.ConvBn2d``/``nni.ConvBnReLU2d`` names must equal the
      original Conv2d names;
    * ``compare_models`` must report the two models equal on a
      (3, 224, 224) input.
    """
    original = _create_test_model(model_class=model_class)
    conv_layers = {
        name for name, module in original.named_modules()
        if isinstance(module, nn.Conv2d)
    }
    bn_layers = {
        name for name, module in original.named_modules()
        if isinstance(module, nn.BatchNorm2d)
    }
    # Get all ReLU except for "post_activation". NOTE(review): presumably
    # these ReLUs have no preceding BN to fuse with; confirm against
    # fuse_model.
    relu_layers = {
        name for name, module in original.named_modules()
        if isinstance(module, nn.ReLU) and "post_activation" not in name
    }
    # With no BN layers the Identity check below would be vacuous.
    assert bn_layers

    # Fuse conv, bn and relu.
    fused = copy.deepcopy(original)
    fused.fuse_model(fuse_relu=True)
    fused_modules = dict(fused.named_modules())

    # Check that every BN and ReLU layer was replaced by Identity. Look each
    # original name up explicitly. Filtering fused.named_modules() by name
    # would silently skip a layer that was dropped or renamed.
    assert all(
        isinstance(fused_modules.get(name), nn.Identity)
        for name in bn_layers | relu_layers
    )

    # Check that all Conv/BN/ReLU groups were merged.
    conv_bn_layers = {
        name for name, module in fused_modules.items()
        if isinstance(module, (nni.ConvBn2d, nni.ConvBnReLU2d))
    }
    assert conv_layers == conv_bn_layers

    # Validate output.
    assert compare_models(original, fused, (3, 224, 224))
def test_identical(self):
    """A resnet50 restored via load_state_from_checkpoint equals the original.

    The original's state dict is serialized, pickled into a temporary file,
    and loaded into a freshly constructed model on the same device.
    """
    model_class = nupic.research.frameworks.pytorch.models.resnets.resnet50
    model_args = dict(num_classes=3)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def fresh_model():
        # nn.Module.to returns the module itself.
        return model_class(**model_args).to(device)

    original = fresh_model()

    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, original.state_dict())
        checkpoint_state = {"model": buffer.getvalue()}

    with tempfile.NamedTemporaryFile(delete=True) as checkpoint_file:
        pickle.dump(checkpoint_state, checkpoint_file)
        checkpoint_file.flush()
        restored = fresh_model()
        load_state_from_checkpoint(restored, checkpoint_file.name, device)

    self.assertTrue(compare_models(original, restored, (3, 224, 224)))
def test_creaate_model_from_checkpoint(self):
    """The test model rebuilt from a pickled checkpoint file matches its source.

    Only the serialized model state is saved, with optimizer and other
    experiment state left out (see ImagenetExperiment.get_state).
    """
    source = _create_test_model()

    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, source.state_dict())
        checkpoint_state = {"model": buffer.getvalue()}

    with tempfile.NamedTemporaryFile() as checkpoint_file:
        # Ray saves checkpoints as pickled dicts.
        pickle.dump(checkpoint_state, checkpoint_file)
        checkpoint_file.file.flush()
        restored = create_model(
            model_class=resnet50,
            model_args=TEST_MODEL_ARGS,
            init_batch_norm=False,
            device="cpu",
            checkpoint_file=checkpoint_file.name,
        )

    self.assertTrue(compare_models(source, restored, (3, 32, 32)))
def test_identical(self):
    """A linear network always compares equal to itself."""
    net = simple_linear_net()
    input_shape = (32,)
    self.assertTrue(compare_models(net, net, input_shape))
def test_different(self):
    """Two independently created linear networks do not compare equal."""
    left, right = simple_linear_net(), simple_linear_net()
    input_shape = (32,)
    self.assertFalse(compare_models(left, right, input_shape))
def test_conv_identical(self):
    """A convolutional network always compares equal to itself."""
    net = simple_conv_net()
    input_shape = (1, 32, 32)
    self.assertTrue(compare_models(net, net, input_shape))
def test_conv_different(self):
    """Two independently created conv networks do not compare equal."""
    left, right = simple_conv_net(), simple_conv_net()
    input_shape = (1, 32, 32)
    self.assertFalse(compare_models(left, right, input_shape))