Example #1
 def branch_function(self,
                     retrain_d: hparams.DatasetHparams,
                     retrain_t: hparams.TrainingHparams,
                     start_at_step_zero: bool = False,
                     transfer_learn: bool = False):
     # Get the mask and model.
     if transfer_learn:
         m = models.registry.load(self.level_root,
                                  self.lottery_desc.train_end_step,
                                  self.lottery_desc.model_hparams)
     else:
         m = models.registry.load(self.level_root,
                                  self.lottery_desc.train_start_step,
                                  self.lottery_desc.model_hparams)
     m = PrunedModel(m, Mask.load(self.level_root))
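     # Re-express the rewind step in the retraining dataset's units, since
     # retrain_d may have a different number of iterations per epoch.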
     start_step = Step.from_iteration(
         0 if start_at_step_zero else
         self.lottery_desc.train_start_step.iteration,
         datasets.registry.iterations_per_epoch(retrain_d))
     train.standard_train(m,
                          self.branch_root,
                          retrain_d,
                          retrain_t,
                          start_step=start_step,
                          verbose=self.verbose)
Example #2
    def _pretrain(self):
        location = self.desc.run_path(self.replicate, 'pretrain')
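        # Nothing to do if the final pretraining checkpoint already exists.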
        if models.registry.exists(location, self.desc.pretrain_end_step): return

        if self.verbose and get_platform().is_primary_process: print('-'*82 + '\nPretraining\n' + '-'*82)
        model = models.registry.get(self.desc.model_hparams, outputs=self.desc.pretrain_outputs)
        train.standard_train(model, location, self.desc.pretrain_dataset_hparams, self.desc.pretrain_training_hparams,
                             verbose=self.verbose, evaluate_every_epoch=self.evaluate_every_epoch,
                             weight_save_steps=self.weight_save_steps)
Example #3
    def branch_function(self,
                        target_model_name: str = None,
                        block_mapping: str = None,
                        start_at_step_zero: bool = False):
        # Process the mapping
        # A valid string format of a mapping is like:
        #   `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
        if 'cifar' in target_model_name and 'resnet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'cifar' in target_model_name and 'vgg' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'mnist' in target_model_name and 'lenet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        else:
            raise NotImplementedError(
                'Other mapping cases not implemented yet')

        # Load source model at `train_start_step`
        src_mask = Mask.load(self.level_root)
        start_step = self.lottery_desc.str_to_step(
            '0it'
        ) if start_at_step_zero else self.lottery_desc.train_start_step
        # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
        src_model = models.registry.load(self.level_root, start_step,
                                         self.lottery_desc.model_hparams)

        # Create target model
        target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
        target_model_hparams.model_name = target_model_name
        target_model = models.registry.get(target_model_hparams)
        target_ones_mask = Mask.ones_like(target_model)

        # Do the morphism
        target_sd = change_depth(target_model_name, src_model.state_dict(),
                                 target_model.state_dict(), mappings)
        target_model.load_state_dict(target_sd)
        target_mask = change_depth(target_model_name, src_mask,
                                   target_ones_mask, mappings)
        target_model = PrunedModel(target_model, target_mask)

        # Save and run a standard train
        target_mask.save(self.branch_root)
        train.standard_train(target_model,
                             self.branch_root,
                             self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose)
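
The mapping string documented above, e.g. '0:0;1:1,2;2:3,4;3:5,6;4:7,8', maps each source block index to one or more target block indices (with '|' separating stages for the multi-stage models). A minimal sketch of a parser for that format, assuming a dict-of-lists result; this is hypothetical, not necessarily the repository's actual parse_block_mapping_for_stage:

    def parse_block_mapping_for_stage(mapping: str) -> dict:
        # Hypothetical sketch: '0:0;1:1,2' -> {0: [0], 1: [1, 2]}.
        result = {}
        for entry in mapping.split(';'):
            src, targets = entry.split(':')
            result[int(src)] = [int(t) for t in targets.split(',')]
        return result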
Example #4
    def _train_level(self, level: int):
        location = self.desc.run_path(self.replicate, level)
        if models.registry.exists(location, self.desc.train_end_step): return

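        # Rewind: reload the level-0 weights at train_start_step and apply this level's mask.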
        model = models.registry.load(self.desc.run_path(self.replicate, 0), self.desc.train_start_step,
                                     self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(model, Mask.load(location))
        pruned_model.save(location, self.desc.train_start_step)
        if self.verbose and get_platform().is_primary_process:
            print('-'*82 + '\nPruning Level {}\n'.format(level) + '-'*82)
        train.standard_train(pruned_model, location, self.desc.dataset_hparams, self.desc.training_hparams,
                             start_step=self.desc.train_start_step, verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch, weight_save_steps=self.weight_save_steps)
Example #5
 def run(self):
     if self.verbose and get_platform().is_primary_process:
         print('=' * 82 +
               f'\nTraining a Model (Replicate {self.replicate})\n' +
               '-' * 82)
         print(self.desc.display)
         print(f'Output Location: {self.desc.run_path(self.replicate)}' +
               '\n' + '=' * 82 + '\n')
     self.desc.save(self.desc.run_path(self.replicate))
     train.standard_train(models.registry.get(self.desc.model_hparams),
                          self.desc.run_path(self.replicate),
                          self.desc.dataset_hparams,
                          self.desc.training_hparams,
                          evaluate_every_epoch=self.evaluate_every_epoch)
Example #6
 def branch_function(self, start_at_step_zero: bool = False):
     model = PrunedModel(
         models.registry.get(self.lottery_desc.model_hparams),
         Mask.load(self.level_root))
     start_step = self.lottery_desc.str_to_step(
         '0it'
     ) if start_at_step_zero else self.lottery_desc.train_start_step
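     # Copy the level mask into the branch directory so the branch is self-contained.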
     Mask.load(self.level_root).save(self.branch_root)
     train.standard_train(model,
                          self.branch_root,
                          self.lottery_desc.dataset_hparams,
                          self.lottery_desc.training_hparams,
                          start_step=start_step,
                          verbose=self.verbose)
Example #7
    def branch_function(self, seed: int, strategy: str = 'layerwise', start_at: str = 'rewind',
                        layers_to_ignore: str = ''):
        # Randomize the mask.
        mask = Mask.load(self.level_root)

        # Randomize while keeping the same layerwise proportions as the original mask.
        if strategy == 'layerwise': mask = Mask(shuffle_state_dict(mask, seed=seed))

        # Randomize globally throughout all prunable layers.
        elif strategy == 'global': mask = Mask(unvectorize(shuffle_tensor(vectorize(mask), seed=seed), mask))

        # Randomize evenly across all layers.
        elif strategy == 'even':
            sparsity = mask.sparsity
            for i, k in enumerate(sorted(mask.keys())):
                n = mask[k].numel()
                cutoff = int(torch.ceil(torch.tensor(sparsity * n)))
                layer_mask = torch.where(torch.arange(n) < cutoff,
                                         torch.ones(n), torch.zeros(n))
                mask[k] = shuffle_tensor(layer_mask, seed=seed + i).reshape(mask[k].shape)

        # Identity.
        elif strategy == 'identity': pass

        # Error.
        else: raise ValueError(f'Invalid strategy: {strategy}')

        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): mask[k] = torch.ones_like(mask[k])

        # Save the new mask.
        mask.save(self.branch_root)

        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = PrunedModel(models.registry.load(self.level_root, state_step, self.lottery_desc.model_hparams), mask)
        train.standard_train(model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
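
The randomization strategies above rely on a few helpers that the excerpt does not define. Minimal sketches of plausible implementations, assuming a mask behaves as a dict of tensors (hypothetical, not the repository's actual code):

    import torch

    def shuffle_tensor(t: torch.Tensor, seed: int) -> torch.Tensor:
        # Permute the elements of t with a fixed seed.
        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(t.numel(), generator=g)
        return t.reshape(-1)[perm].reshape(t.shape)

    def shuffle_state_dict(state_dict, seed: int) -> dict:
        # Shuffle each layer independently, varying the seed per layer.
        return {k: shuffle_tensor(state_dict[k], seed=seed + i)
                for i, k in enumerate(sorted(state_dict.keys()))}

    def vectorize(state_dict) -> torch.Tensor:
        # Flatten all tensors into one vector (keys sorted for determinism).
        return torch.cat([state_dict[k].reshape(-1) for k in sorted(state_dict.keys())])

    def unvectorize(vector: torch.Tensor, reference) -> dict:
        # Inverse of vectorize: split the vector back into reference's shapes.
        result, offset = {}, 0
        for k in sorted(reference.keys()):
            n = reference[k].numel()
            result[k] = vector[offset:offset + n].reshape(reference[k].shape)
            offset += n
        return result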
Example #8
    def run(self):
        if self.verbose and get_platform().is_primary_process:
            print('='*82 + f'\nTraining a Model (Replicate {self.replicate})\n' + '-'*82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' + '\n' + '='*82 + '\n')
        self.desc.save(self.desc.run_path(self.replicate))

        #TODO: make mask and model init paths configurable
        init_path = os.path.join(get_platform().root, 'resnet18_lth')
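        # Step.from_str parses the step string; the second argument (here a
        # hard-coded 1000) is the number of iterations per epoch to convert with.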
        model = models.registry.load(init_path, Step.from_str('2ep218it', 1000),
                                    self.desc.model_hparams, self.desc.train_outputs)
        pruned_model = PrunedModel(model, Mask.load(init_path))

        train.standard_train(
            #models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
            pruned_model, self.desc.run_path(self.replicate),
            self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch)
Example #9
    def branch_function(self,
                        seed: int,
                        strategy: str = 'sparse_global',
                        start_at: str = 'rewind',
                        layers_to_ignore: str = ''):

        # Determine the start step.
        if start_at == 'init':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = start_step
        elif start_at == 'end':
            start_step = self.lottery_desc.str_to_step('0ep')
            state_step = self.lottery_desc.train_end_step
        elif start_at == 'rewind':
            start_step = self.lottery_desc.train_start_step
            state_step = start_step
        else:
            raise ValueError(f'Invalid starting point {start_at}')

        # Train the model with the new mask.
        model = models.registry.load(self.pretrain_root, state_step, self.lottery_desc.model_hparams)

        # Get the current level mask and get the target pruning ratio
        mask = Mask.load(self.level_root)
        sparsity_ratio = mask.get_sparsity_ratio()
        target_pruning_fraction = 1.0 - sparsity_ratio

        # Run pruning
        pruning_hparams = copy.deepcopy(self.lottery_desc.pruning_hparams)
        pruning_hparams.pruning_strategy = strategy
        pruning_hparams.pruning_fraction = target_pruning_fraction
        new_mask = pruning.registry.get(pruning_hparams)(
            model, Mask.ones_like(model),
            self.lottery_desc.training_hparams,
            self.lottery_desc.dataset_hparams, seed
        )
        # Reset the masks of any layers that shouldn't be pruned.
        if layers_to_ignore:
            for k in layers_to_ignore.split(','): new_mask[k] = torch.ones_like(new_mask[k])

        new_mask.save(self.branch_root)

        repruned_model = PrunedModel(model.to(device=get_platform().cpu_device), new_mask)

        # Run training
        train.standard_train(repruned_model, self.branch_root, self.lottery_desc.dataset_hparams,
                             self.lottery_desc.training_hparams, start_step=start_step, verbose=self.verbose)
Example #10
    def branch_function(self,
                        target_model_name: str = None,
                        block_mapping: str = None,
                        start_at_step_zero: bool = False,
                        data_seed: int = 118):
        # Process the mapping
        # A valid string format of a mapping is like:
        #   `0:0;1:1,2;2:3,4;3:5,6;4:7,8`
        if 'cifar' in target_model_name and 'resnet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'imagenet' in target_model_name and 'resnet' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'vggnfc' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'cifar' in target_model_name and 'vgg' in target_model_name:
            mappings = list(
                map(parse_block_mapping_for_stage, block_mapping.split('|')))
        elif 'cifar' in target_model_name and 'mobilenetv1' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        elif 'mnist' in target_model_name and 'lenet' in target_model_name:
            mappings = parse_block_mapping_for_stage(block_mapping)
        else:
            raise NotImplementedError(
                'Other mapping cases not implemented yet')

        # Load source model at `train_start_step`
        src_mask = Mask.load(self.level_root)
        start_step = self.lottery_desc.str_to_step(
            '0it'
        ) if start_at_step_zero else self.lottery_desc.train_start_step
        # model = PrunedModel(models.registry.get(self.lottery_desc.model_hparams), src_mask)
        src_model = models.registry.load(self.level_root, start_step,
                                         self.lottery_desc.model_hparams)

        # Create target model
        target_model_hparams = copy.deepcopy(self.lottery_desc.model_hparams)
        target_model_hparams.model_name = target_model_name
        target_model = models.registry.get(target_model_hparams)
        target_ones_mask = Mask.ones_like(target_model)

        # Do the morphism
        target_sd = change_depth(target_model_name, src_model.state_dict(),
                                 target_model.state_dict(), mappings)
        target_model.load_state_dict(target_sd)
        target_mask = change_depth(target_model_name, src_mask,
                                   target_ones_mask, mappings)
        target_model_a = PrunedModel(target_model, target_mask)
        target_model_b = copy.deepcopy(target_model_a)

        # Save and run a standard train on model a
        seed_a = data_seed + 9999
        training_hparams_a = copy.deepcopy(self.lottery_desc.training_hparams)
        training_hparams_a.data_order_seed = seed_a
        output_dir_a = os.path.join(self.branch_root, f'seed_{seed_a}')
        target_mask.save(output_dir_a)
        train.standard_train(target_model_a,
                             output_dir_a,
                             self.lottery_desc.dataset_hparams,
                             training_hparams_a,
                             start_step=start_step,
                             verbose=self.verbose)

        # Save and run a standard train on model b
        seed_b = data_seed + 10001
        training_hparams_b = copy.deepcopy(self.lottery_desc.training_hparams)
        training_hparams_b.data_order_seed = seed_b
        output_dir_b = os.path.join(self.branch_root, f'seed_{seed_b}')
        target_mask.save(output_dir_b)
        train.standard_train(target_model_b,
                             output_dir_b,
                             self.lottery_desc.dataset_hparams,
                             training_hparams_b,
                             start_step=start_step,
                             verbose=self.verbose)

        # Linear connectivity between model_a and model_b
        training_hparams_c = copy.deepcopy(self.lottery_desc.training_hparams)
        training_hparams_c.training_steps = '1ep'
        for alpha in np.linspace(0, 1.0, 21):
            model_c = linear_interpolate(target_model_a, target_model_b, alpha)
            output_dir_c = os.path.join(self.branch_root, f'alpha_{alpha}')
            # Measure acc of model_c
            train.standard_train(model_c,
                                 output_dir_c,
                                 self.lottery_desc.dataset_hparams,
                                 training_hparams_c,
                                 start_step=None,
                                 verbose=self.verbose)
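
linear_interpolate is not defined in the excerpt; a minimal sketch, assuming it blends the two trained models' parameters elementwise (hypothetical):

    import copy

    def linear_interpolate(model_a, model_b, alpha: float):
        # Return a copy of model_a whose weights are (1 - alpha) * a + alpha * b.
        model_c = copy.deepcopy(model_a)
        sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
        model_c.load_state_dict({k: (1 - alpha) * sd_a[k] + alpha * sd_b[k]
                                 for k in sd_a})
        return model_c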
Example #11
    def run(self):
        location = self.desc.run_path(self.replicate)
        if self.verbose and get_platform().is_primary_process:
            print(
                '=' * 82 +
                f'\nTraining a Model with Knowledge Distillation (Replicate {self.replicate})\n'
                + '-' * 82)
            print(self.desc.display)
            print(f'Output Location: {self.desc.run_path(self.replicate)}' +
                  '\n' + '=' * 82 + '\n')

        if get_platform().is_primary_process: self.desc.save(location)

        # if get_platform().is_primary_process: self._establish_initial_weights()
        # get_platform().barrier()

        # Get the student model
        # student = models.registry.get(self.desc.model_hparams, outputs=self.desc.train_outputs)
        assert 'score-' in self.desc.model_hparams.model_name
        student = self._establish_initial_weights()

        # Get the teacher model
        teacher_model_hparams = deepcopy(self.desc.model_hparams)
        teacher_model_hparams.model_name = self.desc.distill_hparams.teacher_model_name
        teacher = models.registry.load_from_file(
            self.desc.distill_hparams.teacher_ckpt, teacher_model_hparams,
            self.desc.train_outputs)
        teacher_mask = Mask.load(self.desc.distill_hparams.teacher_mask)
        teacher = PrunedModel(teacher, teacher_mask)

        # Run training with knowledge distillation
        if models.registry.exists(location,
                                  self.desc.train_end_step,
                                  suffix='_distill'):
            student = models.registry.load(location,
                                           self.desc.train_end_step,
                                           self.desc.model_hparams,
                                           self.desc.train_outputs,
                                           suffix='_distill')
        else:
            train.distill_train(student,
                                teacher,
                                location,
                                self.desc.dataset_hparams,
                                self.desc.training_hparams,
                                self.desc.distill_hparams,
                                evaluate_every_epoch=self.evaluate_every_epoch,
                                suffix='_distill')

        # Use the distilled student model to do the pruning
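        # (apply_score_to_weight presumably folds the learned importance scores
        # into the weight tensors so a magnitude-style pruner can rank them.)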
        student.apply_score_to_weight()
        # TODO: tweak pruning hparams to match teacher's sparsity level
        pruning.registry.get(self.desc.pruning_hparams)(student).save(
            location, suffix='_distill')

        # Train a new student model in the standard manner with the above mask
        new_student_model_hparams = deepcopy(self.desc.model_hparams)
        new_student_model_hparams.model_name = new_student_model_hparams.model_name.replace(
            'score-', '')
        new_student = models.registry.load(location,
                                           self.desc.train_start_step,
                                           new_student_model_hparams,
                                           self.desc.train_outputs,
                                           strict=False,
                                           suffix='_score')
        new_student_mask = Mask.load(location, suffix='_distill')
        new_student = PrunedModel(new_student, new_student_mask)

        # Run standard training for the new student
        train.standard_train(new_student,
                             location,
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             start_step=self.desc.train_start_step,
                             verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch,
                             suffix='_post_distill')
Example #12
    def branch_function(
            self,
            strategy: str,
            prune_fraction: float,
            prune_experiment: str = 'main',
            prune_step: str = '0ep0it',  # The step for states used to prune.
            prune_highest: bool = False,
            prune_iterations: int = 1,
            randomize_layerwise: bool = False,
            state_experiment: str = 'main',
            state_step: str = '0ep0it',  # The step of the state to use alongside the pruning mask.
            start_step: str = '0ep0it',  # The step at which to start the learning rate schedule.
            seed: int = None,
            reinitialize: bool = False):
        # Get the steps for each part of the process.
        iterations_per_epoch = datasets.registry.iterations_per_epoch(
            self.desc.dataset_hparams)
        prune_step = Step.from_str(prune_step, iterations_per_epoch)
        state_step = Step.from_str(state_step, iterations_per_epoch)
        start_step = Step.from_str(start_step, iterations_per_epoch)
        seed = self.replicate if seed is None else seed

        # Try to load the mask.
        try:
            mask = Mask.load(self.branch_root)
        except Exception:
            mask = None

        result_folder = "Data_Distribution/"
        if reinitialize:
            result_folder = "Data_Distribution_Reinit/"
        elif randomize_layerwise:
            result_folder = "Data_Distribution__Randomize_Layerwise/"

        if not mask and get_platform().is_primary_process:
            # Gather the weights that will be used for pruning.
            prune_path = self.desc.run_path(self.replicate, prune_experiment)
            prune_model = models.registry.load(prune_path, prune_step,
                                               self.desc.model_hparams)

            # Ensure that a valid strategy is available.
            strategy_class = [s for s in strategies if s.valid_name(strategy)]
            if not strategy_class:
                raise ValueError(f'No such pruning strategy {strategy}')
            if len(strategy_class) > 1:
                raise ValueError('Multiple matching strategies')
            strategy_instance = strategy_class[0](strategy, self.desc, seed)

            # Run the strategy for each iteration.
            mask = Mask.ones_like(prune_model)
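            # Per-iteration fraction chosen so that `prune_iterations` rounds compound
            # to `prune_fraction` overall: (1 - f)^prune_iterations = 1 - prune_fraction.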
            iteration_fraction = 1 - (1 - prune_fraction)**(
                1 / float(prune_iterations))

            if iteration_fraction > 0:
                for it in range(0, prune_iterations):
                    # Make a defensive copy of the model and mask out the pruned weights.
                    prune_model2 = copy.deepcopy(prune_model)
                    with torch.no_grad():
                        for k, v in prune_model2.named_parameters():
                            v.mul_(mask.get(k, 1))

                    # Compute the scores.
                    scores = strategy_instance.score(prune_model2, mask)

                    # Prune.
                    mask = unvectorize(
                        prune(vectorize(scores),
                              iteration_fraction,
                              not prune_highest,
                              mask=vectorize(mask)), mask)

            # Shuffle randomly per layer.
            if randomize_layerwise: mask = shuffle_state_dict(mask, seed=seed)

            mask = Mask({k: v.clone().detach() for k, v in mask.items()})
            mask.save(self.branch_root)

            # Plot graphs (Move below mask save?)

            # plot_distribution_scores(strategy_instance.score(prune_model, mask), strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)
            # plot_distribution_scatter(strategy_instance.score(prune_model, mask), prune_model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

            # pdb.set_trace()

        # Load the mask.
        get_platform().barrier()
        mask = Mask.load(self.branch_root)

        # Load the state to use alongside the mask.
        state_path = self.desc.run_path(self.replicate, state_experiment)
        if reinitialize: model = models.registry.get(self.desc.model_hparams)
        else:
            model = models.registry.load(state_path, state_step,
                                         self.desc.model_hparams)

        # plot_distribution_weights(model, strategy, mask, prune_iterations, reinitialize, randomize_layerwise, result_folder)

        original_model = copy.deepcopy(model)
        model = PrunedModel(model, mask)

        # pdb.set_trace()
        train.standard_train(model,
                             self.branch_root,
                             self.desc.dataset_hparams,
                             self.desc.training_hparams,
                             start_step=start_step,
                             verbose=self.verbose,
                             evaluate_every_epoch=self.evaluate_every_epoch)

        weights_analysis(original_model, strategy, reinitialize,
                         randomize_layerwise, "original")
        weights_analysis(model, strategy, reinitialize, randomize_layerwise,
                         "pruned")
Example #13
    def _train_level(self, level: int):
        location = self.desc.run_path(self.replicate, level)  # path for this level

        if models.registry.exists(location, self.desc.train_end_step):

            # If the image directory does not exist yet, create it and plot (added by me).
            IMAGE_PATH = os.path.join(location, 'Distribution_of_Weight')
            if not os.path.exists(IMAGE_PATH):
                os.makedirs(IMAGE_PATH)

                # Load the weights and plot their distributions (added by me).
                print('\nPlotting Location is: ', location)
                model = models.registry.load(
                    self.desc.run_path(self.replicate, 0),
                    self.desc.train_start_step,
                    self.desc.model_hparams, self.desc.train_outputs)

                # Load the originally saved parameters. If the batch size changes,
                # the epoch values must be adjusted (default: batch_size=16).
                for ep, iteration in [[0, 0], [149, 234]]:
                    model.load_state_dict(
                        torch.load(os.path.join(
                            location, 'model_ep{}_it{}.pth'.format(ep, iteration))))
                    model.eval()
                    print("\nmodel_ep{}_it{}.pth".format(ep, iteration))

                    for param_tensor in model.state_dict():
                        #print(param_tensor, "\n", model.state_dict()[param_tensor])
                        tensor = model.state_dict()[param_tensor]
                        tensor = tensor.numpy()

                        # Extract only the weights from the tensor
                        tensor = tensor[0]
                        #print(tensor)
                        tensor = tensor.reshape(-1)
                        sns.kdeplot(tensor)

                        plt.savefig(os.path.join(
                            IMAGE_PATH,
                            'Distribution_of_weights_level{}_ep{}.png'.format(level, ep)))
                    plt.clf()
            """
            # Load Weights before & after Training => 내가 추가
            for ep in range(14):
                for Label in ["Before","After"]:
                    print("\n {} Training \n".format(Label))
                    model.load_state_dict(torch.load('{}\weights\Record_Weights_{}_ep{}.pth'.format(location,Label, ep)),strict = False)
                    model.eval()

                    for param_tensor in model.state_dict():
                        print(param_tensor, "\t", model.state_dict()[param_tensor])
                        #print(param_tensor, "\t", model.state_dict()[param_tensor].size())
            """

            return

        # If training already finished, the early return above fires and we never reach this point.
        model = models.registry.load(self.desc.run_path(self.replicate, 0),
                                     self.desc.train_start_step,
                                     self.desc.model_hparams,
                                     self.desc.train_outputs)
        # From level 7 on (i.e. once ~21% of weights remain): if double_param_level=True,
        # the weights are masked and then doubled at initialization; if layer_different=True,
        # each layer is masked and then scaled by its own constant.
        if level >= 7:
            pruned_model = PrunedModel(model,
                                       Mask.load(location),
                                       double_param_level=False,
                                       layer_different=True)
        else:
            pruned_model = PrunedModel(model,
                                       Mask.load(location))  # load the model and mask

        pruned_model.save(location,
                          self.desc.train_start_step)  # save the pruned model
        #print(f'Pruned Model is: {PrunedModel(model, Mask.load(location))}\n')
        #print(f'Mask.load is: {Mask.load(location)}')
        if self.verbose and get_platform().is_primary_process:
            print('-' * 82 + '\nPruning Level {}\n'.format(level) +
                  '-' * 82)  # print the pruning level (level = 0, level = 1, ...)
        """
        # Added by me (kept as a reference for tensor manipulation).
        if level > 7:
            pruned_model.eval()
            for param_tensor in pruned_model.state_dict():
                tensor = pruned_model.state_dict()[param_tensor]
                tensor = tensor.numpy()
                tensor = tensor * 2
                pruned_model.state_dict()[param_tensor] = tensor
            pruned_model.save(location, self.desc.train_start_step)
        """

        train.standard_train(
            pruned_model,
            location,
            self.desc.dataset_hparams,
            self.desc.training_hparams,
            start_step=self.desc.train_start_step,
            verbose=self.verbose,
            evaluate_every_epoch=self.evaluate_every_epoch
        )  # run training, using the parameters from train_start_step