def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, seed: int = 42):
    """Magnitude-prune the lowest-magnitude weights globally, then shuffle the
    resulting mask across all prunable layers (random-pruning baseline).

    Args:
        pruning_hparams: supplies `pruning_fraction` and the comma-separated
            `pruning_layers_to_ignore` exclusion list.
        trained_model: the trained model whose weight magnitudes drive pruning.
        current_mask: the mask from the previous level; defaults to all ones.
        seed: seed for the global mask shuffle. Previously hard-coded to 42;
            the default preserves the original behavior.

    Returns:
        A new Mask with the target fraction of remaining weights removed and
        the surviving positions randomized globally.
    """
    current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

    # Determine the number of weights that need to be pruned.
    number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
    number_of_weights_to_prune = np.ceil(
        pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_hparams.pruning_layers_to_ignore:
        prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

    # Get the model weights as numpy arrays.
    weights = {k: v.clone().cpu().detach().numpy()
               for k, v in trained_model.state_dict().items()
               if k in prunable_tensors}

    # Create a vector of all the unpruned weights in the model and pick the
    # magnitude threshold below which weights are pruned.
    weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in weights.items()])
    threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune]

    new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
                     for k, v in weights.items()})
    # Layers excluded from pruning keep their previous mask entries.
    for k in current_mask:
        if k not in new_mask:
            new_mask[k] = current_mask[k]

    # Randomize globally throughout all prunable layers.
    new_mask = Mask(unvectorize(shuffle_tensor(vectorize(new_mask), seed=seed), new_mask))

    return new_mask
def test_load_state_dict(self):
    """Round-trip a trained PrunedModel through state_dict()/load_state_dict()."""
    m = Mask.ones_like(self.model)
    m['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    self.model.layer.weight.data = torch.arange(10, 20, dtype=torch.float32)

    pm = PrunedModel(self.model, m)
    self.assertEqual(70.0, pm(self.example).item())

    # Take one optimization step so the saved state is non-trivial.
    self.optimizer.zero_grad()
    pm.loss_criterion(pm(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(67.0, pm(self.example).item())

    # Save the old state dict.
    state_dict = pm.state_dict()

    # Build a fresh model; its output differs from the trained one.
    self.model = InnerProductModel(10)
    pm = PrunedModel(self.model, Mask.ones_like(self.model))
    self.assertEqual(45.0, pm(self.example).item())

    # Restoring the saved state reproduces the trained output...
    pm.load_state_dict(state_dict)
    self.assertEqual(67.0, pm(self.example).item())

    # ...and further training behaves as if it had never been interrupted.
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
    self.optimizer.zero_grad()
    pm.loss_criterion(pm(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(64.3, np.round(pm(self.example).item(), 1))
def test_create_from_tensor(self):
    """A Mask can be constructed directly from a dict of torch tensors."""
    m = Mask({'hello': torch.ones([2, 3]), 'world': torch.zeros([5, 6])})

    self.assertEqual(len(m), 2)
    self.assertEqual(len(m.keys()), 2)
    self.assertEqual(len(m.values()), 2)
    self.assertEqual(set(m.keys()), {'hello', 'world'})
    self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
    self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
def branch_function(self, target_model_name: str = None, block_mapping: str = None, start_at_step_zero: bool = False):
    """Morph a trained-and-pruned source model into a deeper/shallower target
    architecture via a block mapping, then train the morphed model.

    Args:
        target_model_name: registry name of the architecture to morph into.
        block_mapping: stage-to-block mapping string; stages are separated
            by '|' for multi-stage architectures (imagenet resnet, cifar vgg).
        start_at_step_zero: if True, training restarts from iteration 0
            instead of the lottery rewind step.
    """
    # Process the mapping
    # A valid string format of a mapping is like:
    # `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
    if 'cifar' in target_model_name and 'resnet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
        # One mapping per stage, '|'-separated.
        mappings = list(
            map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
        # NOTE: 'vggnfc' must be checked before the plain 'vgg' case below,
        # since 'vgg' is a substring of 'vggnfc'.
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'cifar' in target_model_name and 'vgg' in target_model_name:
        mappings = list(
            map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'mnist' in target_model_name and 'lenet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    else:
        raise NotImplementedError(
            'Other mapping cases not implemented yet')

    # Load source model at `train_start_step`
    src_mask = Mask.load(self.level_root)
    start_step = self.lottery_desc.str_to_step(
        '0it'
    ) if start_at_step_zero else self.lottery_desc.train_start_step
    # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
    src_model = models.registry.load(self.level_root, start_step,
                                     self.lottery_desc.model_hparams)

    # Create target model
    target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
    target_model_hparams.model_name = target_model_name
    target_model = models.registry.get(target_model_hparams)
    target_ones_mask = Mask.ones_like(target_model)

    # Do the morphism: map both the weights and the pruning mask from the
    # source architecture onto the target architecture.
    target_sd = change_depth(target_model_name, src_model.state_dict(),
                             target_model.state_dict(), mappings)
    target_model.load_state_dict(target_sd)
    target_mask = change_depth(target_model_name, src_mask, target_ones_mask,
                               mappings)
    target_model = PrunedModel(target_model, target_mask)

    # Save and run a standard train
    target_mask.save(self.branch_root)
    train.standard_train(target_model, self.branch_root,
                         self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams,
                         start_step=start_step, verbose=self.verbose)
def _prune_level(self, level: int):
    """Create and save the mask for `level`, unless it already exists."""
    new_location = self.desc.run_path(self.replicate, level)
    if Mask.exists(new_location):
        return

    if level == 0:
        # Level 0 is unpruned: save an all-ones mask.
        Mask.ones_like(models.registry.get(self.desc.model_hparams)).save(new_location)
        return

    # Otherwise, prune the model that finished training at the previous level.
    old_location = self.desc.run_path(self.replicate, level - 1)
    model = models.registry.load(old_location, self.desc.train_end_step,
                                 self.desc.model_hparams, self.desc.train_outputs)
    pruning.registry.get(self.desc.pruning_hparams)(model, Mask.load(old_location)).save(new_location)
def branch_function(self, start_at_step_zero: bool = False):
    """Retrain this level's pruned model with the existing mask.

    Fix: the original loaded the mask from disk twice (once for the model,
    once to re-save it); it is now loaded a single time.

    Args:
        start_at_step_zero: if True, training restarts from iteration 0
            instead of the lottery rewind step.
    """
    mask = Mask.load(self.level_root)
    model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), mask)
    start_step = (self.lottery_desc.str_to_step('0it')
                  if start_at_step_zero else self.lottery_desc.train_start_step)

    # Record the mask used for this branch alongside its outputs.
    mask.save(self.branch_root)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None, training_hparams: hparams.TrainingHparams = None, dataset_hparams: hparams.DatasetHparams = None, data_order_seed: int = None):
    """Prune the lowest-scoring weights globally, using per-weight scores
    computed by `Strategy.get_score` instead of raw weight magnitudes.

    Args:
        pruning_hparams: supplies `pruning_fraction` and the comma-separated
            `pruning_layers_to_ignore` exclusion list.
        trained_model: the trained model to score and prune.
        current_mask: the mask from the previous pruning level; defaults to
            all ones.
        training_hparams, dataset_hparams, data_order_seed: forwarded to the
            scoring strategy (presumably for data-dependent scores such as
            gradient-based saliency — confirm against Strategy.get_score).

    Returns:
        A new Mask with the target fraction of remaining weights removed.
    """
    current_mask = Mask.ones_like(
        trained_model) if current_mask is None else current_mask
    current_mask_numpy = current_mask.numpy()

    # Determine the number of weights that need to be pruned.
    number_of_remaining_weights = np.sum(
        [np.sum(v) for v in current_mask_numpy.values()])
    number_of_weights_to_prune = np.ceil(
        pruning_hparams.pruning_fraction *
        number_of_remaining_weights).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_hparams.pruning_layers_to_ignore:
        prunable_tensors -= set(
            pruning_hparams.pruning_layers_to_ignore.split(','))

    # Get the per-weight scores that replace raw weight magnitudes.
    scores = Strategy.get_score(trained_model, current_mask, prunable_tensors,
                                training_hparams, dataset_hparams,
                                data_order_seed)

    # Gather scores of currently-unpruned weights and pick the global
    # threshold below which weights are pruned.
    score_vector = np.concatenate(
        [v[current_mask_numpy[k] == 1] for k, v in scores.items()])
    threshold = np.sort(np.abs(score_vector))[number_of_weights_to_prune]

    new_mask = Mask({
        k: np.where(
            np.abs(v) > threshold, current_mask_numpy[k], np.zeros_like(v))
        for k, v in scores.items()
    })
    # Layers excluded from pruning keep their previous mask entries.
    for k in current_mask_numpy:
        if k not in new_mask:
            new_mask[k] = current_mask_numpy[k]

    return new_mask
def branch_function(self, retrain_d: hparams.DatasetHparams, retrain_t: hparams.TrainingHparams,
                    start_at_step_zero: bool = False, transfer_learn: bool = False):
    """Retrain the masked model with alternate dataset/training hparams."""
    # Choose which checkpoint to start from: end-of-training weights when
    # transfer learning, otherwise the rewind-point weights.
    load_step = (self.lottery_desc.train_end_step if transfer_learn
                 else self.lottery_desc.train_start_step)
    m = models.registry.load(self.level_root, load_step, self.lottery_desc.model_hparams)
    m = PrunedModel(m, Mask.load(self.level_root))

    # Express the start step in the retrain dataset's iterations-per-epoch.
    start_iteration = 0 if start_at_step_zero else self.lottery_desc.train_start_step.iteration
    start_step = Step.from_iteration(start_iteration,
                                     datasets.registry.iterations_per_epoch(retrain_d))
    train.standard_train(m, self.branch_root, retrain_d, retrain_t,
                         start_step=start_step, verbose=self.verbose)
def test_state_dict(self):
    """state_dict() exposes both the wrapped weights and the mask buffers."""
    pm = PrunedModel(self.model, Mask.ones_like(self.model))
    expected_keys = {'model.layer.weight', 'mask_layer___weight'}
    self.assertEqual(expected_keys, pm.state_dict().keys())
def branch_function(self, seed: int, strategy: str = 'sparse_global', start_at: str = 'rewind', layers_to_ignore: str = ''):
    """Re-prune the model to this level's sparsity with a fresh pruning run,
    then retrain with the new mask.

    Bug fix: the `layers_to_ignore` reset originally ran at the top of the
    function, before any mask variable existed, so a non-empty
    `layers_to_ignore` raised NameError. The reset now applies to the newly
    computed mask before it is saved.
    """
    # Determine the start step.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Load the model whose state drives the new pruning run.
    model = models.registry.load(self.pretrain_root, state_step,
                                 self.lottery_desc.model_hparams)

    # Get the current level mask and derive the target pruning ratio.
    mask = Mask.load(self.level_root)
    sparsity_ratio = mask.get_sparsity_ratio()
    target_pruning_fraction = 1.0 - sparsity_ratio

    # Run pruning from a fresh all-ones mask at the target fraction.
    pruning_hparams = copy.deepcopy(self.lottery_desc.pruning_hparams)
    pruning_hparams.pruning_strategy = strategy
    pruning_hparams.pruning_fraction = target_pruning_fraction
    new_mask = pruning.registry.get(pruning_hparams)(
        model,
        Mask.ones_like(model),
        self.lottery_desc.training_hparams,
        self.lottery_desc.dataset_hparams,
        seed
    )

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            new_mask[k] = torch.ones_like(new_mask[k])

    new_mask.save(self.branch_root)
    repruned_model = PrunedModel(model.to(device=get_platform().cpu_device), new_mask)

    # Run training.
    train.standard_train(repruned_model, self.branch_root,
                         self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams,
                         start_step=start_step, verbose=self.verbose)
def test_mask_with_ones_backward(self):
    """An all-ones mask must not interfere with gradient updates."""
    pm = PrunedModel(self.model, Mask.ones_like(self.model))
    self.optimizer.zero_grad()
    loss = pm.loss_criterion(pm(self.example), self.label)
    loss.backward()
    self.optimizer.step()
    self.assertEqual(44.0, pm(self.example).item())
def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None): current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy() # number of initializations num_inits = next(iter(current_mask.values())).shape[0] assert np.array([num_inits == v.shape[0] for v in current_mask.values()]).all() # Determine the number of weights that need to be pruned. number_of_remaining_weights_per_init = np.sum([np.sum(v) for v in current_mask.values()]) // num_inits number_of_weights_to_prune_per_init = np.ceil( pruning_hparams.pruning_fraction * number_of_remaining_weights_per_init).astype(int) # Determine which layers can be pruned. prunable_tensors = set(trained_model.prunable_layer_names) if pruning_hparams.pruning_layers_to_ignore: prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(',')) # Get the model weights. weights = {k: v.clone().cpu().detach().numpy() for k, v in trained_model.state_dict().items() if k in prunable_tensors} # Create a vector of all the unpruned weights in the model. weight_vectors = [ np.concatenate( [ v[init_id, ...][current_mask[k][init_id,...] == 1] for k, v in weights.items() ] ) for init_id in range(num_inits)] thresholds = np.array([ np.sort(np.abs(wv))[number_of_weights_to_prune_per_init] for wv in weight_vectors ]) mask_dict = {} for k, v in weights.items(): threshold_tensor = thresholds.reshape(-1, *[1 for _ in range(v.ndim-1)]) threshold_tensor = np.tile(threshold_tensor, v.shape[1:]) mask_dict[k] = np.where(np.abs(v) > threshold_tensor, current_mask[k], np.zeros_like(v)) new_mask = Mask(mask_dict) for k in current_mask: if k not in new_mask: new_mask[k] = current_mask[k] return new_mask
def test_save(self):
    """Saving a PrunedModel writes a checkpoint the bare model can reload."""
    before = self.get_state(self.model)
    pm = PrunedModel(self.model, Mask.ones_like(self.model))
    pm.save(self.root, Step.zero(20))

    checkpoint = paths.model(self.root, Step.zero(20))
    self.assertTrue(os.path.exists(checkpoint))

    # The saved weights reload into the unwrapped model unchanged.
    self.model.load_state_dict(torch.load(checkpoint))
    self.assertStateEqual(before, self.get_state(self.model))
def test_ones_like(self):
    """Mask.ones_like covers exactly the prunable tensors with all-ones values."""
    model = models.registry.get(models.registry.get_default_hparams('cifar_resnet_20').model_hparams)
    m = Mask.ones_like(model)

    for k, v in model.state_dict().items():
        if k in model.prunable_layer_names:
            # Prunable tensors get a same-shaped, all-ones mask entry.
            self.assertIn(k, m)
            self.assertEqual(list(m[k].shape), list(v.shape))
            self.assertTrue((m[k] == 1).all())
        else:
            # Non-prunable tensors are absent from the mask.
            self.assertNotIn(k, m)
def _prune_level(self, level: int):
    """Create and save the pruning mask for the given level, if absent."""
    new_location = self.desc.run_path(self.replicate, level)
    if Mask.exists(new_location):
        return

    if level == 0:
        # At level 0 the mask is all ones => every weight survives.
        Mask.ones_like(models.registry.get(self.desc.model_hparams)).save(
            new_location)
    else:
        # Otherwise, old_location is the run_path saved for the previous level.
        old_location = self.desc.run_path(
            self.replicate, level - 1)
        model = models.registry.load(
            old_location, self.desc.train_end_step, self.desc.model_hparams,
            self.desc.train_outputs
        )  # Pruning works on the fully trained weights, hence train_end_step.
        pruning.registry.get(self.desc.pruning_hparams)(
            model, Mask.load(old_location)
        ).save(
            new_location
        )  # registry.get returns a partial of .prune, so this call prunes and the result is saved to new_location.
def test_dict_behavior(self):
    """Mask supports the standard mapping protocol: insert, lookup, delete."""
    m = Mask()
    self.assertEqual(len(m), 0)
    self.assertEqual(len(m.keys()), 0)
    self.assertEqual(len(m.values()), 0)

    # Insert two entries.
    m['hello'] = np.ones([2, 3])
    m['world'] = np.zeros([5, 6])
    self.assertEqual(len(m), 2)
    self.assertEqual(len(m.keys()), 2)
    self.assertEqual(len(m.values()), 2)
    self.assertEqual(set(m.keys()), {'hello', 'world'})
    self.assertTrue(np.array_equal(np.ones([2, 3]), m['hello']))
    self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))

    # Delete one and verify the other is untouched.
    del m['hello']
    self.assertEqual(len(m), 1)
    self.assertEqual(len(m.keys()), 1)
    self.assertEqual(len(m.values()), 1)
    self.assertEqual(set(m.keys()), {'world'})
    self.assertTrue(np.array_equal(np.zeros([5, 6]), m['world']))
def test_level3_4it_pretrain2it(self):
    """End-to-end lottery run with a 2-iteration pretrain phase followed by
    three 4-iteration pruning levels; checks weight hand-off and masking."""
    self.desc.pretrain_dataset_hparams = copy.deepcopy(
        self.desc.dataset_hparams)
    self.desc.pretrain_training_hparams = copy.deepcopy(
        self.desc.training_hparams)
    self.desc.pretrain_training_hparams.training_steps = '2it'
    self.desc.training_hparams.training_steps = '4it'
    LotteryRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run()

    # Check that the pretrain weights are present.
    pretrain_root = self.desc.run_path(2, 'pretrain')
    self.assertLevelFilesPresent(pretrain_root, self.to_step('0it'),
                                 self.to_step('2it'), masks=False)

    # Load the pretrain and level0 start weights to ensure they're the same.
    pretrain_end_weights = paths.model(self.desc.run_path(2, 'pretrain'),
                                       self.desc.pretrain_end_step)
    pretrain_end_weights = {
        k: v.numpy()
        for k, v in torch.load(pretrain_end_weights).items()
    }
    level0_weights = paths.model(self.desc.run_path(2, 0),
                                 self.desc.train_start_step)
    level0_weights = {
        k: v.numpy()
        for k, v in torch.load(level0_weights).items()
    }
    self.assertStateEqual(pretrain_end_weights, level0_weights)

    # Evaluate each of the pruning levels.
    for level in range(0, 2):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('2it'),
                                     self.to_step('4it'))

        # Ensure that the initial weights are a masked version of the level 0 weights
        # (which are identical to the weights at the end of pretraining).
        mask = Mask.load(level_root).numpy()
        level_weights = paths.model(level_root, self.desc.train_start_step)
        level_weights = {
            k: v.numpy()
            for k, v in torch.load(level_weights).items()
        }
        # Tensors absent from the mask (non-prunable) are compared unmasked
        # via the `.get(k, 1)` default.
        self.assertStateEqual(
            level_weights,
            {k: v * mask.get(k, 1)
             for k, v in level0_weights.items()})
def _train_level(self, level: int):
    """Train the masked model for `level`, skipping work already done."""
    location = self.desc.run_path(self.replicate, level)
    if models.registry.exists(location, self.desc.train_end_step):
        return

    # Rewind: every level starts from the level-0 weights at the start step.
    base_model = models.registry.load(self.desc.run_path(self.replicate, 0),
                                      self.desc.train_start_step,
                                      self.desc.model_hparams, self.desc.train_outputs)
    pruned_model = PrunedModel(base_model, Mask.load(location))
    pruned_model.save(location, self.desc.train_start_step)

    if self.verbose and get_platform().is_primary_process:
        print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82)

    train.standard_train(pruned_model, location, self.desc.dataset_hparams,
                         self.desc.training_hparams,
                         start_step=self.desc.train_start_step, verbose=self.verbose,
                         evaluate_every_epoch=self.evaluate_every_epoch,
                         weight_save_steps=self.weight_save_steps)
def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind', layers_to_ignore: str = ''):
    """Randomize the current level's pruning mask and retrain with it.

    Fixes in this revision (both in the 'even' strategy, which previously
    could not run at all):
      * `for i, k in sorted(mask.keys())` tried to unpack each key string
        into (i, k); it now enumerates the sorted keys.
      * `mask[k].size` is a method on torch tensors, so `torch.arange(...)`,
        `torch.ones_like(...)`, and `.reshape(...)` were all called on a
        bound method; element counts now come from numel() and the final
        reshape uses `.shape`.
    """
    # Randomize the mask.
    mask = Mask.load(self.level_root)

    # Randomize while keeping the same layerwise proportions as the original mask.
    if strategy == 'layerwise':
        mask = Mask(shuffle_state_dict(mask, seed=seed))

    # Randomize globally throughout all prunable layers.
    elif strategy == 'global':
        mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

    # Randomize evenly across all layers: give every layer the same overall
    # sparsity, then shuffle within each layer.
    elif strategy == 'even':
        sparsity = mask.sparsity
        for i, k in enumerate(sorted(mask.keys())):
            num_elements = mask[k].numel()
            num_kept = torch.ceil(torch.as_tensor(sparsity * num_elements))
            layer_mask = torch.where(torch.arange(num_elements) < num_kept,
                                     torch.ones(num_elements), torch.zeros(num_elements))
            mask[k] = shuffle_tensor(layer_mask, seed=seed+i).reshape(mask[k].shape)

    # Identity.
    elif strategy == 'identity':
        pass

    # Error.
    else:
        raise ValueError(f'Invalid strategy: {strategy}')

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            mask[k] = torch.ones_like(mask[k])

    # Save the new mask.
    mask.save(self.branch_root)

    # Determine the start step.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Train the model with the new mask.
    model = PrunedModel(models.registry.load(self.level_root, state_step,
                                             self.lottery_desc.model_hparams), mask)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
def test_bad_inputs(self):
    """Mask rejects bad keys and non-binary or non-array values."""
    m = Mask()

    # Keys must be non-empty strings.
    with self.assertRaises(ValueError):
        m[''] = np.ones([2, 3])
    with self.assertRaises(ValueError):
        m[6] = np.ones([2, 3])

    # Values must be arrays/tensors containing only 0s and 1s.
    with self.assertRaises(ValueError):
        m['hello'] = [[0, 1], [1, 0]]
    with self.assertRaises(ValueError):
        m['hello'] = np.array([[0, 1], [2, 0]])
def test_level3_2it(self):
    """End-to-end lottery run with three pruning levels of 2 iterations each;
    checks per-level sparsity (0.8^level), the sparsity report, and that each
    level's initial weights are the masked level-0 weights."""
    self.desc.training_hparams.training_steps = '2it'
    LotteryRunner(replicate=2, levels=3, desc=self.desc, verbose=False).run()

    level0_weights = paths.model(self.desc.run_path(2, 0),
                                 self.to_step('0it'))
    level0_weights = {
        k: v.numpy()
        for k, v in torch.load(level0_weights).items()
    }

    for level in range(0, 4):
        level_root = self.desc.run_path(2, level)
        self.assertLevelFilesPresent(level_root, self.to_step('0it'),
                                     self.to_step('2it'))

        # Check the mask.
        # Expected fraction of unpruned weights after `level` rounds at 20%.
        pct = 0.8**level
        mask = Mask.load(level_root).numpy()

        # Check the mask itself.
        total, total_present = 0.0, 0.0
        for v in mask.values():
            total += v.size
            total_present += np.sum(v)
        self.assertTrue(np.allclose(pct, total_present / total, atol=0.01))

        # Check the sparsity report.
        with open(paths.sparsity_report(level_root)) as fp:
            sparsity_report = json.loads(fp.read())
        self.assertTrue(
            np.allclose(pct,
                        sparsity_report['unpruned'] /
                        sparsity_report['total'],
                        atol=0.01))

        # Ensure that the initial weights are a masked version of the level 0 weights.
        level_weights = paths.model(level_root, self.to_step('0it'))
        level_weights = {
            k: v.numpy()
            for k, v in torch.load(level_weights).items()
        }
        # Non-prunable tensors are compared unmasked via `.get(k, 1)`.
        self.assertStateEqual(
            level_weights,
            {k: v * mask.get(k, 1)
             for k, v in level0_weights.items()})
def run(self):
    """Train a pre-pruned ResNet-18 checkpoint with this run's hparams.

    Loads weights and mask from a fixed on-disk location rather than
    initializing a fresh model.
    """
    if self.verbose and get_platform().is_primary_process:
        print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82)
        print(self.desc.display)
        print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n')
    self.desc.save(self.desc.run_path(self.replicate))

    #TODO: make mask and model init paths configurable
    # NOTE(review): the checkpoint step '2ep218it' and the 1000
    # iterations-per-epoch value are hard-coded — confirm they match the
    # pretrained artifact stored under `resnet18_lth`.
    init_path = os.path.join(get_platform().root, 'resnet18_lth')
    model = models.registry.load(init_path, Step.from_str('2ep218it', 1000),
                                 self.desc.model_hparams, self.desc.train_outputs)
    pruned_model = PrunedModel(model, Mask.load(init_path))
    train.standard_train(
        #models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
        pruned_model, self.desc.run_path(self.replicate),
        self.desc.dataset_hparams, self.desc.training_hparams,
        evaluate_every_epoch=self.evaluate_every_epoch)
def test_save_load_exists(self):
    """save()/load()/exists() round-trip a mask through disk."""
    self.assertFalse(Mask.exists(self.root))
    self.assertFalse(os.path.exists(paths.mask(self.root)))

    original = Mask({'hello': np.ones([2, 3]), 'world': np.zeros([5, 6])})
    original.save(self.root)
    self.assertTrue(os.path.exists(paths.mask(self.root)))
    self.assertTrue(Mask.exists(self.root))

    # Reloading yields the same keys and values.
    reloaded = Mask.load(self.root)
    self.assertEqual(len(reloaded), 2)
    self.assertEqual(len(reloaded.keys()), 2)
    self.assertEqual(len(reloaded.values()), 2)
    self.assertEqual(set(reloaded.keys()), {'hello', 'world'})
    self.assertTrue(np.array_equal(np.ones([2, 3]), reloaded['hello']))
    self.assertTrue(np.array_equal(np.zeros([5, 6]), reloaded['world']))
def test_level0_2it(self):
    """A levels=0 run trains once and leaves an all-ones mask."""
    self.desc.training_hparams.training_steps = '2it'
    LotteryRunner(replicate=2, levels=0, desc=self.desc, verbose=False).run()
    level_root = self.desc.run_path(2, 0)

    # Ensure the important files are there.
    self.assertLevelFilesPresent(level_root, self.to_step('0it'), self.to_step('2it'))

    # The level-0 mask must keep every weight.
    for v in Mask.load(level_root).numpy().values():
        self.assertTrue(np.all(np.equal(v, 1)))

    # The sparsity report must agree: nothing pruned.
    with open(paths.sparsity_report(level_root)) as fp:
        sparsity_report = json.loads(fp.read())
    self.assertEqual(sparsity_report['unpruned'] / sparsity_report['total'], 1)
def test_with_mask(self):
    """A binary mask zeroes weights at construction and keeps them zero after a step."""
    mask = Mask()
    mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    pm = PrunedModel(self.model, mask)

    # The forward pass sees only the surviving (even-indexed) weights.
    self.assertEqual(20.0, pm(self.example).item())

    # Construction applied the mask to the underlying weights in place.
    expected = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0])
    self.assertTrue(np.array_equal(self.model.state_dict()['layer.weight'].numpy(), expected))

    # One optimization step updates only the unmasked weights.
    self.optimizer.zero_grad()
    pm.loss_criterion(pm(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(22.0, pm(self.example).item())

    expected = np.array([0.4, 0, 2.4, 0, 4.4, 0, 6.4, 0, 8.4, 0])
    self.assertTrue(np.allclose(self.model.state_dict()['layer.weight'].numpy(), expected))
def __init__(self, model: Model, mask: Mask):
    """Wrap `model` so that `mask` permanently zeroes its pruned weights.

    Raises:
        ValueError: if `model` is already a PrunedModel, if any prunable
            tensor lacks a mask entry, if a mask entry has the wrong shape,
            or if the mask names a tensor that is not prunable.
    """
    if isinstance(model, PrunedModel):
        raise ValueError('Cannot nest pruned models.')

    super(PrunedModel, self).__init__()
    self.model = model

    # Every prunable tensor needs a correctly-shaped mask entry.
    state = self.model.state_dict()
    for name in self.model.prunable_layer_names:
        if name not in mask:
            raise ValueError('Missing mask value {}.'.format(name))
        if not np.array_equal(mask[name].shape, np.array(state[name].shape)):
            raise ValueError(
                'Incorrect mask shape {} for tensor {}.'.format(
                    mask[name].shape, name))

    # And the mask must not name anything that isn't prunable.
    for name in mask:
        if name not in self.model.prunable_layer_names:
            raise ValueError(
                'Key {} found in mask but is not a valid model tensor.'.
                format(name))

    # Register each mask as a buffer so it travels with state_dict() and
    # device moves, then zero out the pruned weights immediately.
    for name, value in mask.items():
        self.register_buffer(PrunedModel.to_mask_name(name), value.float())
    self._apply_mask()
def test_with_incorrect_shape(self):
    """A mask entry whose shape mismatches the model tensor is rejected."""
    mask = Mask.ones_like(self.model)
    mask['layer.weight'] = np.ones(30)  # wrong shape for this layer
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def test_with_excess_mask_value(self):
    """A mask naming a tensor the model doesn't have is rejected."""
    mask = Mask.ones_like(self.model)
    mask['layer2.weight'] = np.ones(20)  # no such tensor in the model
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def test_with_missing_mask_value(self):
    """A mask missing an entry for a prunable tensor is rejected."""
    mask = Mask.ones_like(self.model)
    del mask['layer.weight']  # drop the required entry
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def test_mask_with_ones_forward(self):
    """An all-ones mask leaves the forward pass unchanged."""
    pm = PrunedModel(self.model, Mask.ones_like(self.model))
    self.assertEqual(45.0, pm(self.example).item())