def test_state_dict(self):
    mask = Mask.ones_like(self.model)
    pruned_model = PrunedModel(self.model, mask)
    state_dict = pruned_model.state_dict()
    self.assertEqual(set(['model.layer.weight', 'mask_layer___weight']), state_dict.keys())
def test_mask_with_ones_backward(self):
    mask = Mask.ones_like(self.model)
    pruned_model = PrunedModel(self.model, mask)

    self.optimizer.zero_grad()
    pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
    self.optimizer.step()

    self.assertEqual(44.0, pruned_model(self.example).item())
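# For context, the PrunedModel tests in this file appear to assume a fixture
# along these lines (a sketch; the actual setUp may differ). With weights
# [0, 1, ..., 9], an all-ones example, an MSE loss, and a target of 40.0, the
# dense forward pass gives 45.0, and one SGD step at lr=0.01 subtracts
# 0.01 * 2 * (45 - 40) = 0.1 from each weight, giving the 44.0 asserted above.
def setUp(self):
    self.model = InnerProductModel(10)              # weight = [0, 1, ..., 9]
    self.example = torch.ones(1, 10)                # all-ones input
    self.label = torch.ones(1) * 40.0               # regression target
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)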
def test_save(self):
    state1 = self.get_state(self.model)
    mask = Mask.ones_like(self.model)
    pruned_model = PrunedModel(self.model, mask)
    pruned_model.save(self.root, Step.zero(20))

    self.assertTrue(os.path.exists(paths.model(self.root, Step.zero(20))))

    self.model.load_state_dict(torch.load(paths.model(self.root, Step.zero(20))))
    self.assertStateEqual(state1, self.get_state(self.model))
def _train_level(self, level: int):
    location = self.desc.run_path(self.replicate, level)
    if models.registry.exists(location, self.desc.train_end_step):
        return

    model = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step,
                                 self.desc.model_hparams, self.desc.train_outputs)
    pruned_model = PrunedModel(model, Mask.load(location))
    pruned_model.save(location, self.desc.train_start_step)

    if self.verbose and get_platform().is_primary_process:
        print('-' * 82 + '\nPruning Level {}\n'.format(level) + '-' * 82)

    train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams,
                         start_step=self.desc.train_start_step, verbose=self.verbose,
                         evaluate_every_epoch=self.evaluate_every_epoch,
                         weight_save_steps=self.weight_save_steps)
def branch_function(self, retrain_d: hparams.DatasetHparams, retrain_t: hparams.TrainingHparams,
                    start_at_step_zero: bool = False, transfer_learn: bool = False):
    # Get the mask and model.
    if transfer_learn:
        m = models.registry.load(self.level_root, self.lottery_desc.train_end_step,
                                 self.lottery_desc.model_hparams)
    else:
        m = models.registry.load(self.level_root, self.lottery_desc.train_start_step,
                                 self.lottery_desc.model_hparams)
    m = PrunedModel(m, Mask.load(self.level_root))
    start_step = Step.from_iteration(
        0 if start_at_step_zero else self.lottery_desc.train_start_step.iteration,
        datasets.registry.iterations_per_epoch(retrain_d))
    train.standard_train(m, self.branch_root, retrain_d, retrain_t,
                         start_step=start_step, verbose=self.verbose)
def get_score(trained_model: models.base.Model, current_mask: Mask, prunable_tensors: set,
              training_hparams: hparams.TrainingHparams, dataset_hparams: hparams.DatasetHparams,
              data_order_seed: int = None):
    pruned_model = PrunedModel(trained_model, current_mask).to(device=get_platform().torch_device)
    pruned_model._clear_grad()
    # pruned_model._enable_mask_gradient()

    # Calculate the gradient.
    train.accumulate_gradient(training_hparams, pruned_model, dataset_hparams,
                              data_order_seed, verbose=False)

    # Calculate the score.
    scores = dict()
    for name, param in pruned_model.model.named_parameters():
        if hasattr(pruned_model, PrunedModel.to_mask_name(name)) and name in prunable_tensors:
            scores[name] = (param.grad * param).abs().clone().cpu().detach().numpy()
    score_vector = np.concatenate([v.reshape(-1) for k, v in scores.items()])
    norm = np.sum(score_vector)
    for k in scores.keys():
        scores[k] /= norm

    # Clean up.
    pruned_model._clear_grad()
    # model._disable_mask_gradient()
    return scores
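# A minimal sketch (hypothetical helper, not part of the original code) of how the
# normalized scores returned by `get_score` could be turned into a new Mask by
# global thresholding. `mask_from_scores` and `fraction_to_prune` are illustrative
# names; the repository's actual pruning strategy may differ.
def mask_from_scores(scores: dict, current_mask: Mask, fraction_to_prune: float) -> Mask:
    # Rank only the still-unpruned weights by score.
    flat = np.concatenate([scores[k][current_mask[k].numpy() == 1] for k in sorted(scores)])
    cutoff_index = min(int(fraction_to_prune * flat.size), flat.size - 1)
    threshold = np.sort(flat)[cutoff_index]
    # Keep weights whose score exceeds the threshold and that were not already pruned.
    return Mask({k: torch.tensor(((scores[k] > threshold) &
                                  (current_mask[k].numpy() == 1)).astype(float))
                 for k in scores})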
def branch_function(self, target_model_name: str = None, block_mapping: str = None,
                    start_at_step_zero: bool = False):
    # Process the mapping.
    # A valid string format of a mapping is like:
    # `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
    if 'cifar' in target_model_name and 'resnet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
        mappings = list(map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'cifar' in target_model_name and 'vgg' in target_model_name:
        mappings = list(map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'mnist' in target_model_name and 'lenet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    else:
        raise NotImplementedError('Other mapping cases not implemented yet')

    # Load the source model at `train_start_step`.
    src_mask = Mask.load(self.level_root)
    start_step = (self.lottery_desc.str_to_step('0it')
                  if start_at_step_zero else self.lottery_desc.train_start_step)
    # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
    src_model = models.registry.load(self.level_root, start_step, self.lottery_desc.model_hparams)

    # Create the target model.
    target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
    target_model_hparams.model_name = target_model_name
    target_model = models.registry.get(target_model_hparams)
    target_ones_mask = Mask.ones_like(target_model)

    # Do the morphism.
    target_sd = change_depth(target_model_name, src_model.state_dict(),
                             target_model.state_dict(), mappings)
    target_model.load_state_dict(target_sd)
    target_mask = change_depth(target_model_name, src_mask, target_ones_mask, mappings)
    target_model = PrunedModel(target_model, target_mask)

    # Save the mask and run a standard train.
    target_mask.save(self.branch_root)
    train.standard_train(target_model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
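# A minimal sketch (an assumption based on the documented mapping format, not the
# original implementation) of the `parse_block_mapping_for_stage` helper used
# above. It reads a string such as `0:0;1:1,2;2:3,4` into a dict mapping each
# source block to its list of target blocks.
def parse_block_mapping_for_stage(mapping_str: str) -> dict:
    mapping = {}
    for entry in mapping_str.split(';'):
        src, targets = entry.split(':')
        mapping[int(src)] = [int(t) for t in targets.split(',')]
    return mapping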
def test_with_mask(self):
    mask = Mask()
    mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    pruned_model = PrunedModel(self.model, mask)

    # Check that the forward pass gives the correct value.
    self.assertEqual(20.0, pruned_model(self.example).item())

    # Check that the appropriate weights have been zeroed out.
    self.assertTrue(np.array_equal(self.model.state_dict()['layer.weight'].numpy(),
                                   np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0])))

    # Perform a backward pass.
    self.optimizer.zero_grad()
    pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(22.0, pruned_model(self.example).item())

    # Verify the weights.
    self.assertTrue(np.allclose(self.model.state_dict()['layer.weight'].numpy(),
                                np.array([0.4, 0, 2.4, 0, 4.4, 0, 6.4, 0, 8.4, 0])))
def branch_function(self, start_at_step_zero: bool = False):
    mask = Mask.load(self.level_root)
    model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), mask)
    start_step = (self.lottery_desc.str_to_step('0it')
                  if start_at_step_zero else self.lottery_desc.train_start_step)
    mask.save(self.branch_root)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                    layers_to_ignore: str = ''):
    # Randomize the mask.
    mask = Mask.load(self.level_root)

    # Randomize while keeping the same layerwise proportions as the original mask.
    if strategy == 'layerwise':
        mask = Mask(shuffle_state_dict(mask, seed=seed))

    # Randomize globally throughout all prunable layers.
    elif strategy == 'global':
        mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

    # Randomize evenly across all layers.
    elif strategy == 'even':
        sparsity = mask.sparsity
        for i, k in enumerate(sorted(mask.keys())):
            numel = mask[k].numel()
            num_kept = int(np.ceil(float(sparsity) * numel))
            layer_mask = torch.where(torch.arange(numel) < num_kept,
                                     torch.ones(numel), torch.zeros(numel))
            mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(mask[k].shape)

    # Identity.
    elif strategy == 'identity':
        pass

    # Error.
    else:
        raise ValueError(f'Invalid strategy: {strategy}')

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            mask[k] = torch.ones_like(mask[k])

    # Save the new mask.
    mask.save(self.branch_root)

    # Determine the start step.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Train the model with the new mask.
    model = PrunedModel(models.registry.load(self.level_root, state_step,
                                             self.lottery_desc.model_hparams), mask)
    train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
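# Minimal sketches (assumptions, not the repository's own utilities, which may
# differ in detail) of the tensor helpers the randomization strategies above
# rely on: vectorize/unvectorize flatten and restore a state dict, and the
# shuffle helpers permute elements deterministically given a seed.
def vectorize(state_dict):
    # Flatten all tensors into one vector, in sorted key order.
    return torch.cat([state_dict[k].reshape(-1) for k in sorted(state_dict.keys())])

def unvectorize(vector, reference_state_dict):
    # Invert vectorize() using the shapes of a reference state dict.
    result, offset = {}, 0
    for k in sorted(reference_state_dict.keys()):
        size = reference_state_dict[k].numel()
        result[k] = vector[offset:offset + size].reshape(reference_state_dict[k].shape)
        offset += size
    return result

def shuffle_tensor(tensor, seed):
    # Randomly permute the elements of a tensor.
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(tensor.numel(), generator=generator)
    return tensor.reshape(-1)[perm].reshape(tensor.shape)

def shuffle_state_dict(state_dict, seed):
    # Shuffle each tensor independently, varying the seed per layer.
    return {k: shuffle_tensor(state_dict[k], seed + i)
            for i, k in enumerate(sorted(state_dict.keys()))}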
def branch_function(self, seed: int, strategy: str = 'sparse_global', start_at: str = 'rewind',
                    layers_to_ignore: str = ''):
    # Determine the start step.
    if start_at == 'init':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = start_step
    elif start_at == 'end':
        start_step = self.lottery_desc.str_to_step('0ep')
        state_step = self.lottery_desc.train_end_step
    elif start_at == 'rewind':
        start_step = self.lottery_desc.train_start_step
        state_step = start_step
    else:
        raise ValueError(f'Invalid starting point {start_at}')

    # Load the model that will be repruned and retrained.
    model = models.registry.load(self.pretrain_root, state_step, self.lottery_desc.model_hparams)

    # Get the current level mask and compute the target pruning fraction.
    mask = Mask.load(self.level_root)
    sparsity_ratio = mask.get_sparsity_ratio()
    target_pruning_fraction = 1.0 - sparsity_ratio

    # Run pruning.
    pruning_hparams = copy.deepcopy(self.lottery_desc.pruning_hparams)
    pruning_hparams.pruning_strategy = strategy
    pruning_hparams.pruning_fraction = target_pruning_fraction
    new_mask = pruning.registry.get(pruning_hparams)(
        model, Mask.ones_like(model),
        self.lottery_desc.training_hparams, self.lottery_desc.dataset_hparams, seed)

    # Reset the masks of any layers that shouldn't be pruned.
    if layers_to_ignore:
        for k in layers_to_ignore.split(','):
            new_mask[k] = torch.ones_like(new_mask[k])

    new_mask.save(self.branch_root)
    repruned_model = PrunedModel(model.to(device=get_platform().cpu_device), new_mask)

    # Run training.
    train.standard_train(repruned_model, self.branch_root, self.lottery_desc.dataset_hparams,
                         self.lottery_desc.training_hparams, start_step=start_step,
                         verbose=self.verbose)
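# Based on how it is used above (an assumption, not the original source),
# `Mask.get_sparsity_ratio` presumably returns the fraction of weights that
# survive, so pruning 1 - sparsity_ratio of a dense model reproduces the
# level's overall sparsity. A minimal sketch as a Mask method:
def get_sparsity_ratio(self) -> float:
    unpruned = sum(float(v.sum()) for v in self.values())
    total = sum(v.numel() for v in self.values())
    return unpruned / total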
def run(self):
    if self.verbose and get_platform().is_primary_process:
        print('=' * 82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-' * 82)
        print(self.desc.display)
        print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '=' * 82 + '\n')
    self.desc.save(self.desc.run_path(self.replicate))

    # TODO: make mask and model init paths configurable.
    init_path = os.path.join(get_platform().root, 'resnet18_lth')
    model = models.registry.load(init_path, Step.from_str('2ep218it', 1000),
                                 self.desc.model_hparams, self.desc.train_outputs)
    pruned_model = PrunedModel(model, Mask.load(init_path))

    train.standard_train(
        # models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
        pruned_model, self.desc.run_path(self.replicate),
        self.desc.dataset_hparams, self.desc.training_hparams,
        evaluate_every_epoch=self.evaluate_every_epoch)
def _train_level(self, level: int):
    location = self.desc.run_path(self.replicate, level)  # Path for this level.
    if models.registry.exists(location, self.desc.train_end_step):
        # If the image path does not exist, make the directory (added by me).
        IMAGE_PATH = os.path.join(location, 'Distribution_of_Weight')
        if not os.path.exists(IMAGE_PATH):
            os.makedirs(IMAGE_PATH)

        # Load the weights and plot their distributions (added by me).
        print('\nPlotting Location is: ', location)
        model = models.registry.load(self.desc.run_path(self.replicate, 0),
                                     self.desc.train_start_step,
                                     self.desc.model_hparams, self.desc.train_outputs)

        # Load the originally saved parameters. As the batch size changes, the ep
        # values should be adjusted. Default: batch_size=16.
        for ep, iteration in [[0, 0], [149, 234]]:
            model.load_state_dict(torch.load(
                os.path.join(location, 'model_ep{}_it{}.pth'.format(ep, iteration))))
            model.eval()
            print("\nmodel_ep{}_it{}.pth".format(ep, iteration))
            for param_tensor in model.state_dict():
                # print(param_tensor, "\n", model.state_dict()[param_tensor])
                tensor = model.state_dict()[param_tensor]
                tensor = tensor.numpy()
                tensor = tensor[0]  # Extract only the weights from the tensor.
                # print(tensor)
                tensor = tensor.reshape(-1)
                sns.kdeplot(tensor)
            plt.savefig(os.path.join(
                IMAGE_PATH, 'Distribution_of_weights_level{}_ep{}.png'.format(level, ep)))
            plt.clf()

        """
        # Load the weights before & after training (added by me).
        for ep in range(14):
            for Label in ["Before", "After"]:
                print("\n {} Training \n".format(Label))
                model.load_state_dict(torch.load(os.path.join(
                    location, 'weights', 'Record_Weights_{}_ep{}.pth'.format(Label, ep))),
                    strict=False)
                model.eval()
                for param_tensor in model.state_dict():
                    print(param_tensor, "\t", model.state_dict()[param_tensor])
                    # print(param_tensor, "\t", model.state_dict()[param_tensor].size())
        """
        # If this level was already trained, we return here, so the code below never runs.
        return

    model = models.registry.load(self.desc.run_path(self.replicate, 0),
                                 self.desc.train_start_step,
                                 self.desc.model_hparams, self.desc.train_outputs)

    # From level 7 onward (i.e., once about 21% of the weights remain):
    # if double_param_level=True, the weights are doubled after masking at initialization;
    # if layer_different=True, each layer is masked and then scaled by its own constant.
    if level >= 7:
        pruned_model = PrunedModel(model, Mask.load(location),
                                   double_param_level=False, layer_different=True)
    else:
        pruned_model = PrunedModel(model, Mask.load(location))  # Load the model and mask.
    pruned_model.save(location, self.desc.train_start_step)  # Save the pruned model.
    # print(f'Pruned Model is: {PrunedModel(model, Mask.load(location))}\n')
    # print(f'Mask.load is: {Mask.load(location)}')

    if self.verbose and get_platform().is_primary_process:
        # Display the pruning level (level 0, level 1, ...).
        print('-' * 82 + '\nPruning Level {}\n'.format(level) + '-' * 82)

    """
    # Added by me (kept as a reference for tensor manipulation).
    if level > 7:
        pruned_model.eval()
        for param_tensor in pruned_model.state_dict():
            tensor = pruned_model.state_dict()[param_tensor]
            tensor = tensor.numpy()
            tensor = tensor * 2
            pruned_model.state_dict()[param_tensor] = tensor
        pruned_model.save(location, self.desc.train_start_step)
    """

    # Run training; training uses the parameters from train_start_step.
    train.standard_train(pruned_model, location, self.desc.dataset_hparams,
                         self.desc.training_hparams, start_step=self.desc.train_start_step,
                         verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch)
def branch_function(
        self,
        strategy: str,
        prune_fraction: float,
        prune_experiment: str = 'main',
        prune_step: str = '0ep0it',  # The step whose state is used to prune.
        prune_highest: bool = False,
        prune_iterations: int = 1,
        randomize_layerwise: bool = False,
        state_experiment: str = 'main',
        state_step: str = '0ep0it',  # The step of the state to use alongside the pruning mask.
        start_step: str = '0ep0it',  # The step at which to start the learning rate schedule.
        seed: int = None,
        reinitialize: bool = False):
    # Get the steps for each part of the process.
    iterations_per_epoch = datasets.registry.iterations_per_epoch(self.desc.dataset_hparams)
    prune_step = Step.from_str(prune_step, iterations_per_epoch)
    state_step = Step.from_str(state_step, iterations_per_epoch)
    start_step = Step.from_str(start_step, iterations_per_epoch)
    seed = self.replicate if seed is None else seed

    # Try to load the mask.
    try:
        mask = Mask.load(self.branch_root)
    except Exception:
        mask = None

    result_folder = "Data_Distribution/"
    if reinitialize:
        result_folder = "Data_Distribution_Reinit/"
    elif randomize_layerwise:
        result_folder = "Data_Distribution__Randomize_Layerwise/"

    if not mask and get_platform().is_primary_process:
        # Gather the weights that will be used for pruning.
        prune_path = self.desc.run_path(self.replicate, prune_experiment)
        prune_model = models.registry.load(prune_path, prune_step, self.desc.model_hparams)

        # Ensure that a valid strategy is available.
        strategy_class = [s for s in strategies if s.valid_name(strategy)]
        if not strategy_class:
            raise ValueError(f'No such pruning strategy {strategy}')
        if len(strategy_class) > 1:
            raise ValueError('Multiple matching strategies')
        strategy_instance = strategy_class[0](strategy, self.desc, seed)

        # Run the strategy for each iteration.
        mask = Mask.ones_like(prune_model)
        iteration_fraction = 1 - (1 - prune_fraction) ** (1 / float(prune_iterations))
        if iteration_fraction > 0:
            for it in range(0, prune_iterations):
                # Make a defensive copy of the model and mask out the pruned weights.
                prune_model2 = copy.deepcopy(prune_model)
                with torch.no_grad():
                    for k, v in prune_model2.named_parameters():
                        v.mul_(mask.get(k, 1))

                # Compute the scores.
                scores = strategy_instance.score(prune_model2, mask)

                # Prune.
                mask = unvectorize(prune(vectorize(scores), iteration_fraction,
                                         not prune_highest, mask=vectorize(mask)), mask)

        # Shuffle randomly per layer.
        if randomize_layerwise:
            mask = shuffle_state_dict(mask, seed=seed)

        mask = Mask({k: v.clone().detach() for k, v in mask.items()})
        mask.save(self.branch_root)

        # Plot graphs. (Move below mask save?)
        # plot_distribution_scores(strategy_instance.score(prune_model, mask), strategy, mask,
        #                          prune_iterations, reinitialize, randomize_layerwise, result_folder)
        # plot_distribution_scatter(strategy_instance.score(prune_model, mask), prune_model, strategy,
        #                           mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)
        # pdb.set_trace()

    # Load the mask.
    get_platform().barrier()
    mask = Mask.load(self.branch_root)

    # Determine the start step.
    state_path = self.desc.run_path(self.replicate, state_experiment)
    if reinitialize:
        model = models.registry.get(self.desc.model_hparams)
    else:
        model = models.registry.load(state_path, state_step, self.desc.model_hparams)
    # plot_distribution_weights(model, strategy, mask, prune_iterations, reinitialize,
    #                           randomize_layerwise, result_folder)
    original_model = copy.deepcopy(model)
    model = PrunedModel(model, mask)
    # pdb.set_trace()

    train.standard_train(model, self.branch_root, self.desc.dataset_hparams,
                         self.desc.training_hparams, start_step=start_step, verbose=self.verbose,
                         evaluate_every_epoch=self.evaluate_every_epoch)

    weights_analysis(original_model, strategy, reinitialize, randomize_layerwise, "original")
    weights_analysis(model, strategy, reinitialize, randomize_layerwise, "pruned")
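# Note on `iteration_fraction` in the branch function above: pruning a fraction
# f of the surviving weights in each of n iterations leaves (1 - f)^n of the
# weights, so choosing f = 1 - (1 - prune_fraction)**(1 / n) compounds to
# prune_fraction overall. A standalone sanity check with hypothetical values:
import math

prune_fraction, prune_iterations = 0.2, 2
f = 1 - (1 - prune_fraction) ** (1 / prune_iterations)  # about 0.1056 per iteration
assert math.isclose((1 - f) ** prune_iterations, 1 - prune_fraction)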
def test_with_incorrect_shape(self):
    mask = Mask.ones_like(self.model)
    mask['layer.weight'] = np.ones(30)
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def run(self):
    location = self.desc.run_path(self.replicate)
    if self.verbose and get_platform().is_primary_process:
        print('=' * 82 +
              f'\nTraining a Model with Knowledge Distillation (Replicate {self.replicate})\n' +
              '-' * 82)
        print(self.desc.display)
        print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '=' * 82 + '\n')
    if get_platform().is_primary_process:
        self.desc.save(location)
    # if get_platform().is_primary_process: self._establish_initial_weights()
    # get_platform().barrier()

    # Get the student model.
    # student = models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs)
    assert 'score-' in self.desc.model_hparams.model_name
    student = self._establish_initial_weights()

    # Get the teacher model.
    teacher_model_hparams = deepcopy(self.desc.model_hparams)
    teacher_model_hparams.model_name = self.desc.distill_hparams.teacher_model_name
    teacher = models.registry.load_from_file(self.desc.distill_hparams.teacher_ckpt,
                                             teacher_model_hparams, self.desc.train_outputs)
    teacher_mask = Mask.load(self.desc.distill_hparams.teacher_mask)
    teacher = PrunedModel(teacher, teacher_mask)

    # Run training with knowledge distillation.
    if models.registry.exists(location, self.desc.train_end_step, suffix='_distill'):
        student = models.registry.load(location, self.desc.train_end_step, self.desc.model_hparams,
                                       self.desc.train_outputs, suffix='_distill')
    else:
        train.distill_train(student, teacher, location, self.desc.dataset_hparams,
                            self.desc.training_hparams, self.desc.distill_hparams,
                            evaluate_every_epoch=self.evaluate_every_epoch, suffix='_distill')

    # Use the distilled student model to do the pruning.
    student.apply_score_to_weight()
    # TODO: tweak pruning hparams to match the teacher's sparsity level.
    pruning.registry.get(self.desc.pruning_hparams)(student).save(location, suffix='_distill')

    # Train a new student model in the standard manner with the above mask.
    new_student_model_hparams = deepcopy(self.desc.model_hparams)
    new_student_model_hparams.model_name = new_student_model_hparams.model_name.replace('score-', '')
    new_student = models.registry.load(location, self.desc.train_start_step,
                                       new_student_model_hparams, self.desc.train_outputs,
                                       strict=False, suffix='_score')
    new_student_mask = Mask.load(location, suffix='_distill')
    new_student = PrunedModel(new_student, new_student_mask)

    # Run standard training for the new student.
    train.standard_train(new_student, location, self.desc.dataset_hparams,
                         self.desc.training_hparams, start_step=self.desc.train_start_step,
                         verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch,
                         suffix='_post_distill')
def test_with_excess_mask_value(self):
    mask = Mask.ones_like(self.model)
    mask['layer2.weight'] = np.ones(20)
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def test_with_missing_mask_value(self):
    mask = Mask.ones_like(self.model)
    del mask['layer.weight']
    with self.assertRaises(ValueError):
        PrunedModel(self.model, mask)
def test_mask_with_ones_forward(self):
    mask = Mask.ones_like(self.model)
    pruned_model = PrunedModel(self.model, mask)
    self.assertEqual(45.0, pruned_model(self.example).item())
def test_load_state_dict(self):
    mask = Mask.ones_like(self.model)
    mask['layer.weight'] = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    self.model.layer.weight.data = torch.arange(10, 20, dtype=torch.float32)

    pruned_model = PrunedModel(self.model, mask)
    self.assertEqual(70.0, pruned_model(self.example).item())

    self.optimizer.zero_grad()
    pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(67.0, pruned_model(self.example).item())

    # Save the old state dict.
    state_dict = pruned_model.state_dict()

    # Create a new model.
    self.model = InnerProductModel(10)
    mask = Mask.ones_like(self.model)
    pruned_model = PrunedModel(self.model, mask)
    self.assertEqual(45.0, pruned_model(self.example).item())

    # Load the state dict.
    pruned_model.load_state_dict(state_dict)
    self.assertEqual(67.0, pruned_model(self.example).item())

    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
    self.optimizer.zero_grad()
    pruned_model.loss_criterion(pruned_model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(64.3, np.round(pruned_model(self.example).item(), 1))
def branch_function(self, target_model_name: str = None, block_mapping: str = None,
                    start_at_step_zero: bool = False, data_seed: int = 118):
    # Process the mapping.
    # A valid string format of a mapping is like:
    # `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
    if 'cifar' in target_model_name and 'resnet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
        mappings = list(map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'cifar' in target_model_name and 'vgg' in target_model_name:
        mappings = list(map(parse_block_mapping_for_stage, block_mapping.split('|')))
    elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    elif 'mnist' in target_model_name and 'lenet' in target_model_name:
        mappings = parse_block_mapping_for_stage(block_mapping)
    else:
        raise NotImplementedError('Other mapping cases not implemented yet')

    # Load the source model at `train_start_step`.
    src_mask = Mask.load(self.level_root)
    start_step = (self.lottery_desc.str_to_step('0it')
                  if start_at_step_zero else self.lottery_desc.train_start_step)
    # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
    src_model = models.registry.load(self.level_root, start_step, self.lottery_desc.model_hparams)

    # Create the target model.
    target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
    target_model_hparams.model_name = target_model_name
    target_model = models.registry.get(target_model_hparams)
    target_ones_mask = Mask.ones_like(target_model)

    # Do the morphism.
    target_sd = change_depth(target_model_name, src_model.state_dict(),
                             target_model.state_dict(), mappings)
    target_model.load_state_dict(target_sd)
    target_mask = change_depth(target_model_name, src_mask, target_ones_mask, mappings)
    target_model_a = PrunedModel(target_model, target_mask)
    target_model_b = copy.deepcopy(target_model_a)

    # Save the mask and run a standard train on model a.
    seed_a = data_seed + 9999
    training_hparams_a = copy.deepcopy(self.lottery_desc.training_hparams)
    training_hparams_a.data_order_seed = seed_a
    output_dir_a = os.path.join(self.branch_root, f'seed_{seed_a}')
    target_mask.save(output_dir_a)
    train.standard_train(target_model_a, output_dir_a, self.lottery_desc.dataset_hparams,
                         training_hparams_a, start_step=start_step, verbose=self.verbose)

    # Save the mask and run a standard train on model b.
    seed_b = data_seed + 10001
    training_hparams_b = copy.deepcopy(self.lottery_desc.training_hparams)
    training_hparams_b.data_order_seed = seed_b
    output_dir_b = os.path.join(self.branch_root, f'seed_{seed_b}')
    target_mask.save(output_dir_b)
    train.standard_train(target_model_b, output_dir_b, self.lottery_desc.dataset_hparams,
                         training_hparams_b, start_step=start_step, verbose=self.verbose)

    # Linear connectivity between model_a and model_b.
    training_hparams_c = copy.deepcopy(self.lottery_desc.training_hparams)
    training_hparams_c.training_steps = '1ep'
    for alpha in np.linspace(0, 1.0, 21):
        model_c = linear_interpolate(target_model_a, target_model_b, alpha)
        output_dir_c = os.path.join(self.branch_root, f'alpha_{alpha}')
        # Measure the accuracy of model_c.
        train.standard_train(model_c, output_dir_c, self.lottery_desc.dataset_hparams,
                             training_hparams_c, start_step=None, verbose=self.verbose)
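# A minimal sketch (an assumption, not the original helper) of the
# `linear_interpolate` function used in the connectivity sweep above: an
# elementwise convex combination of the two models' parameters. Both models
# share identical mask buffers here, so interpolating the full state dict is safe.
def linear_interpolate(model_a, model_b, alpha: float):
    model_c = copy.deepcopy(model_a)
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    model_c.load_state_dict({k: (1 - alpha) * state_a[k] + alpha * state_b[k]
                             for k in state_a})
    return model_c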