Exemple #1
0
    def forward(self, images, targets):
        """Run teacher and student on ``images`` and combine their losses.

        The teacher is called for its side effects only (its return value is
        discarded); the student's return value is kept as the original loss
        dict. For every (teacher_path, student_path) pair in
        ``self.target_module_pairs`` the ``'distillation_box'`` entry stored in
        each module's ``__dict__`` is read and the recorded outputs are handed
        to ``self.criterion`` together with the student's loss dict.

        Args:
            images: batch passed positionally to both models.
            targets: passed as the second positional argument to the student.

        Returns:
            Whatever ``self.criterion(output_dict, org_loss_dict)`` returns.
        """
        teacher, student = self.teacher_model, self.student_model
        # Module lookups must go through the unwrapped models.
        bare_teacher = teacher.module if isinstance(teacher, DataParallel) else teacher
        bare_student = student.module if isinstance(student, DistributedDataParallel) else student

        shared_kwargs = {}
        if self.require_adjustment:
            # One randomly chosen size per image, shared by teacher and student.
            shared_kwargs['fixed_sizes'] = [
                random.choice(bare_teacher.transform.min_size) for _ in images
            ]

        self.teacher_model(images, **shared_kwargs)
        org_loss_dict = self.student_model(images, targets, **shared_kwargs)

        output_dict = {}
        for teacher_path, student_path in self.target_module_pairs:
            teacher_box = module_util.get_module(bare_teacher, teacher_path).__dict__['distillation_box']
            student_box = module_util.get_module(bare_student, student_path).__dict__['distillation_box']
            teacher_entry = (teacher_box['path_from_root'], teacher_box['output'])
            student_entry = (student_box['path_from_root'], student_box['output'])
            # Keyed by the loss name stored on the teacher side.
            output_dict[teacher_box['loss_name']] = (teacher_entry, student_entry)

        return self.criterion(output_dict, org_loss_dict)
Exemple #2
0
def freeze_modules(student_model, student_model_config):
    """Freeze student parameters according to the model config.

    If the config contains ``'frozen_modules'``, each listed module path is
    resolved with ``module_util.get_module`` and frozen. Otherwise, if it
    contains ``'unfrozen_modules'``, the whole model is frozen first and only
    the listed modules are unfrozen again. ``'frozen_modules'`` takes
    precedence when both keys are present; with neither key nothing happens.

    Args:
        student_model: model whose sub-modules are looked up by path.
        student_model_config: mapping that may hold either key above.
    """
    if 'frozen_modules' in student_model_config:
        for path in student_model_config['frozen_modules']:
            module_util.freeze_module_params(
                module_util.get_module(student_model, path))
        return

    if 'unfrozen_modules' not in student_model_config:
        return

    # Freeze everything, then re-enable only the listed modules.
    module_util.freeze_module_params(student_model)
    for path in student_model_config['unfrozen_modules']:
        module_util.unfreeze_module_params(
            module_util.get_module(student_model, path))
Exemple #3
0
    def forward(self, images, targets):
        """Run the teacher once, the student once per sampled width, and sum the losses.

        When ``self.slimmable`` is true the widths are taken from
        ``self.student_config['backbone']['params']``: the smallest and the
        largest entries of ``'width_mult_list'`` are always used, plus
        ``width_copies - 2`` random picks from the list. Otherwise a single
        width of ``1.0`` is used. For every width, ``set_width`` is applied to
        the student, the student is run, and the ``'distillation_box'`` entries
        of the paired modules are passed to ``self.criterion``.

        Args:
            images: batch passed positionally to both models.
            targets: passed as the second positional argument to the student.

        Returns:
            The sum over all widths of
            ``self.criterion(output_dict, org_loss_dict)``.
        """
        teacher_model_without_dp = \
            self.teacher_model.module if isinstance(self.teacher_model, DataParallel) else self.teacher_model
        student_model_without_ddp = \
            self.student_model.module if isinstance(self.student_model, DistributedDataParallel) else self.student_model

        # The teacher's return value is not used; only the outputs stored in
        # the modules' 'distillation_box' entries are read below.
        self.teacher_model(images)
        total_loss = 0

        train_widths = [1.0]
        if self.slimmable:
            backbone_params = self.student_config['backbone']['params']
            width_mult_list = backbone_params['width_mult_list']
            width_copies = backbone_params['width_copies']
            # NOTE(review): this sorts the list held by the config in place;
            # kept as-is because other code may rely on the sorted order.
            width_mult_list.sort()
            train_widths = [width_mult_list[0]]
            train_widths.extend(
                random.choice(width_mult_list)
                for _ in range(width_copies - 2))
            train_widths.append(width_mult_list[-1])

        for width_mult in train_widths:
            set_width(self.student_model, width_mult)
            org_loss_dict = self.student_model(images, targets)
            output_dict = dict()
            for teacher_path, student_path in self.target_module_pairs:
                teacher_dict = module_util.get_module(
                    teacher_model_without_dp,
                    teacher_path).__dict__['distillation_box']
                student_dict = module_util.get_module(
                    student_model_without_ddp,
                    student_path).__dict__['distillation_box']
                output_dict[teacher_dict['loss_name']] = (
                    (teacher_dict['path_from_root'], teacher_dict['output']),
                    (student_dict['path_from_root'], student_dict['output']))

            total_loss += self.criterion(output_dict, org_loss_dict)

        # Free up memory: drop the references to the stored outputs.
        for teacher_path, student_path in self.target_module_pairs:
            partial_teacher_module = module_util.get_module(
                teacher_model_without_dp, teacher_path)
            partial_student_module = module_util.get_module(
                student_model_without_ddp, student_path)
            partial_teacher_module.__dict__['distillation_box']['output'] = None
            partial_student_module.__dict__['distillation_box']['output'] = None

        return total_loss
Exemple #4
0
def freeze_modules(student_model, student_model_config, reset_unfrozen=False):
    """Freeze student parameters according to the model config.

    If the config contains ``'frozen_modules'``, each listed module path is
    resolved with ``module_util.get_module`` and frozen. Otherwise, if it
    contains ``'unfrozen_modules'``, the whole model is frozen and the listed
    modules are unfrozen again. ``'frozen_modules'`` takes precedence when
    both keys are present; with neither key nothing happens.

    Args:
        student_model: model whose sub-modules are looked up by path.
        student_model_config: mapping that may hold either key above.
        reset_unfrozen: when true, each unfrozen module is additionally
            passed to ``init_weights`` and a message is printed.
    """
    if 'frozen_modules' in student_model_config:
        for path in student_model_config['frozen_modules']:
            module_util.freeze_module_params(
                module_util.get_module(student_model, path))
        return

    if 'unfrozen_modules' not in student_model_config:
        return

    # Freeze everything, then re-enable only the listed modules.
    module_util.freeze_module_params(student_model)
    for path in student_model_config['unfrozen_modules']:
        unfrozen = module_util.get_module(student_model, path)
        module_util.unfreeze_module_params(unfrozen)
        if not reset_unfrozen:
            continue
        print("Reinitializing module: {}".format(path))
        init_weights(unfrozen)
Exemple #5
0
def get_trainable_modules(student_model, student_model_config):
    """Return the student modules that are not configured as frozen.

    With ``'frozen_modules'`` in the config, the result is every module
    yielded by ``student_model.modules()`` except the modules found at the
    listed paths. Otherwise, with ``'unfrozen_modules'``, the result is just
    the modules found at those paths. ``'frozen_modules'`` takes precedence
    when both keys are present.

    Args:
        student_model: model providing ``modules()`` and path-based lookup
            through ``module_util.get_module``.
        student_model_config: mapping that may hold either key above.

    Returns:
        A list of modules, or ``None`` when the config has neither key.
    """
    if 'frozen_modules' in student_model_config:
        all_modules = list(student_model.modules())
        # Each frozen path is resolved exactly once.
        frozen_modules = [
            module_util.get_module(student_model, student_path)
            for student_path in student_model_config['frozen_modules']
        ]
        # NOTE(review): only the listed modules themselves are excluded; any
        # other object yielded by modules() is kept - confirm this is intended.
        return [m for m in all_modules if m not in frozen_modules]

    if 'unfrozen_modules' in student_model_config:
        return [
            module_util.get_module(student_model, student_path)
            for student_path in student_model_config['unfrozen_modules']
        ]

    # Neither key configured: callers receive None, as before.
    return None
Exemple #6
0
    def __init__(self, teacher_model, student_model, criterion_config):
        """Store both models, register output-capturing hooks and build the losses.

        For every entry of the optional ``criterion_config['sub_terms']`` the
        teacher/student module pair named by ``'ts_modules'`` gets a
        ``'distillation_box'`` dict in its ``__dict__`` and a forward hook
        that stores the module's output under that dict's ``'output'`` key.

        Args:
            teacher_model: teacher network, possibly wrapped in ``DataParallel``.
            student_model: student network, possibly wrapped in
                ``DistributedDataParallel``.
            criterion_config: mapping with a required ``'org_term'`` entry
                (holding ``'criterion'`` and ``'factor'``) and an optional
                ``'sub_terms'`` entry.
        """
        super().__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.target_module_pairs = list()

        def extract_output(module, module_input, module_output):
            # Forward hook: keep the latest output on the module itself.
            module.__dict__['distillation_box']['output'] = module_output

        sub_terms_config = criterion_config.get('sub_terms', None)
        if sub_terms_config is not None:
            # Module paths are resolved against the unwrapped models.
            bare_teacher = teacher_model.module if isinstance(teacher_model, DataParallel) else teacher_model
            bare_student = student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
            for loss_name, loss_config in sub_terms_config.items():
                teacher_path, student_path = loss_config['ts_modules']
                self.target_module_pairs.append((teacher_path, student_path))
                teacher_module = module_util.get_module(bare_teacher, teacher_path)
                student_module = module_util.get_module(bare_student, student_path)
                endpoints = (
                    (teacher_module, teacher_path, True),
                    (student_module, student_path, False),
                )
                for module, path, is_teacher in endpoints:
                    module.__dict__['distillation_box'] = {
                        'loss_name': loss_name,
                        'path_from_root': path,
                        'is_teacher': is_teacher
                    }
                for module, _, _ in endpoints:
                    module.register_forward_hook(extract_output)

        org_term_config = criterion_config['org_term']
        self.org_criterion = get_single_loss(org_term_config['criterion'])
        self.org_factor = org_term_config['factor']
        self.criterion = get_custom_loss(criterion_config)
        self.use_teacher_output = isinstance(self.org_criterion, KDLoss)
Exemple #7
0
    def __init__(self, teacher_model, student_model, criterion_config,
                 student_config):
        """Store both models, register output-capturing hooks and build the loss.

        For every entry of ``criterion_config['terms']`` the teacher/student
        module pair named by ``'ts_modules'`` gets a ``'distillation_box'``
        dict in its ``__dict__`` and a forward hook that stores the module's
        output under that dict's ``'output'`` key.

        Args:
            teacher_model: teacher network, possibly wrapped in ``DataParallel``.
            student_model: student network, possibly wrapped in
                ``DistributedDataParallel``.
            criterion_config: mapping with a ``'terms'`` entry; also passed
                whole to ``get_loss``.
            student_config: mapping whose ``['backbone']['params']`` is checked
                for a ``'slimmable'`` key.
        """
        super().__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.target_module_pairs = list()
        backbone_params = student_config['backbone']['params']
        self.slimmable = 'slimmable' in backbone_params
        self.student_config = student_config

        def extract_output(module, module_input, module_output):
            # Forward hook: keep the latest output on the module itself.
            module.__dict__['distillation_box']['output'] = module_output

        # Module paths are resolved against the unwrapped models.
        bare_teacher = teacher_model.module if isinstance(teacher_model, DataParallel) else teacher_model
        bare_student = student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
        for loss_name, loss_config in criterion_config['terms'].items():
            teacher_path, student_path = loss_config['ts_modules']
            self.target_module_pairs.append((teacher_path, student_path))
            teacher_module = module_util.get_module(bare_teacher, teacher_path)
            student_module = module_util.get_module(bare_student, student_path)
            endpoints = (
                (teacher_module, teacher_path, True),
                (student_module, student_path, False),
            )
            for module, path, is_teacher in endpoints:
                module.__dict__['distillation_box'] = {
                    'loss_name': loss_name,
                    'path_from_root': path,
                    'is_teacher': is_teacher
                }
            for module, _, _ in endpoints:
                module.register_forward_hook(extract_output)

        self.criterion = get_loss(criterion_config)
        self.require_adjustment = isinstance(bare_student, KeypointRCNN)
Exemple #8
0
    def forward(self, sample_batch, targets):
        """Run teacher and student on a batch and return the combined loss.

        The original loss is computed with ``self.org_criterion``; when
        ``self.use_teacher_output`` is true the teacher's output is passed to
        it as well. If the student returns a list or tuple, one original loss
        is computed per element and stored under its index. The outputs found
        in the ``'distillation_box'`` entries of the paired modules are then
        passed to ``self.criterion`` together with the original losses.

        Args:
            sample_batch: batch passed to both models.
            targets: passed as the last argument to ``self.org_criterion``.

        Returns:
            Whatever ``self.criterion(output_dict, org_loss_dict)`` returns.
        """
        teacher_outputs = self.teacher_model(sample_batch)
        student_outputs = self.student_model(sample_batch)
        # Model with auxiliary classifier returns multiple outputs
        if isinstance(student_outputs, (list, tuple)):
            if self.use_teacher_output:
                # enumerate() yields (index, item) pairs, so the zipped
                # (student, teacher) pair has to be unpacked as a nested tuple.
                org_loss_dict = {
                    i: self.org_criterion(sub_student_outputs, sub_teacher_outputs, targets)
                    for i, (sub_student_outputs, sub_teacher_outputs)
                    in enumerate(zip(student_outputs, teacher_outputs))
                }
            else:
                org_loss_dict = {
                    i: self.org_criterion(sub_outputs, targets)
                    for i, sub_outputs in enumerate(student_outputs)
                }
        else:
            org_loss = self.org_criterion(student_outputs, teacher_outputs, targets) if self.use_teacher_output\
                else self.org_criterion(student_outputs, targets)
            org_loss_dict = {0: org_loss}

        output_dict = dict()
        # Module lookups must go through the unwrapped models.
        teacher_model_without_dp = \
            self.teacher_model.module if isinstance(self.teacher_model, DataParallel) else self.teacher_model
        student_model_without_ddp = \
            self.student_model.module if isinstance(self.student_model, DistributedDataParallel) else self.student_model
        for teacher_path, student_path in self.target_module_pairs:
            teacher_dict = module_util.get_module(
                teacher_model_without_dp,
                teacher_path).__dict__['distillation_box']
            student_dict = module_util.get_module(
                student_model_without_ddp,
                student_path).__dict__['distillation_box']
            # Keyed by the loss name stored on the teacher side.
            output_dict[teacher_dict['loss_name']] = (
                (teacher_dict['path_from_root'], teacher_dict['output']),
                (student_dict['path_from_root'], student_dict['output']))

        total_loss = self.criterion(output_dict, org_loss_dict)
        return total_loss
Exemple #9
0
def freeze_modules(student_model, student_model_config):
    """Freeze every student module listed under ``'frozen_modules'``.

    Each path is resolved with ``module_util.get_module`` and the resulting
    module is passed to ``module_util.freeze_module_params``.

    Args:
        student_model: model whose sub-modules are looked up by path.
        student_model_config: mapping with a required ``'frozen_modules'``
            entry holding the module paths.
    """
    frozen_paths = student_model_config['frozen_modules']
    for path in frozen_paths:
        module_util.freeze_module_params(
            module_util.get_module(student_model, path))