def setup_teacher_student_models(self, teacher_config, student_config):
        """
        (Re)build the teacher/student models and register their hooks.

        Each side is rebuilt when its config is non-empty, or when the
        config is empty but the corresponding model attribute is still
        ``None``. Otherwise the existing ``self.teacher_model`` /
        ``self.student_model`` is kept as-is.

        Args:
            teacher_config: mapping-like teacher configuration. It is
                queried with ``len()`` and ``.get()`` for the keys
                ``'frozen_modules'``, ``'requires_grad'`` and
                ``'forward_proc'``, and forwarded to the helper functions.
            student_config: same as ``teacher_config`` but for the student.

        Side effects:
            Clears and refills ``self.target_teacher_pairs`` and
            ``self.target_student_pairs``; may replace ``self.teacher_model``
            and ``self.student_model``; sets ``self.teacher_any_frozen``,
            ``self.student_any_frozen``, ``self.teacher_forward_proc`` and
            ``self.student_forward_proc``.
        """
        # Strip the wrapper (if any) so the helpers below see the bare model.
        bare_teacher = self.org_teacher_model
        if check_if_wrapped(bare_teacher):
            bare_teacher = bare_teacher.module

        bare_student = self.org_student_model
        if check_if_wrapped(bare_student):
            bare_student = bare_student.module

        # Pairs from a previous setup are discarded before new hooks are set.
        self.target_teacher_pairs.clear()
        self.target_student_pairs.clear()

        # Reference models default to the unwrapped originals; a special
        # module, when one is built, takes their place.
        teacher_ref_model = bare_teacher
        student_ref_model = bare_student

        if len(teacher_config) > 0 or self.teacher_model is None:
            special_teacher = build_special_module(
                teacher_config, teacher_model=bare_teacher,
                device=self.device, device_ids=self.device_ids,
                distributed=self.distributed)
            if special_teacher is None:
                teacher_type = 'original'
            else:
                teacher_ref_model = special_teacher
                teacher_type = type(special_teacher).__name__
            self.teacher_model = redesign_model(
                teacher_ref_model, teacher_config, 'teacher', teacher_type)

        if len(student_config) > 0 or self.student_model is None:
            special_student = build_special_module(
                student_config, student_model=bare_student,
                device=self.device, device_ids=self.device_ids,
                distributed=self.distributed)
            if special_student is None:
                student_type = 'original'
            else:
                student_ref_model = special_student
                student_type = type(special_student).__name__
            self.student_model = redesign_model(
                student_ref_model, student_config, 'student', student_type)

        # A side counts as "any frozen" when it lists frozen modules or when
        # its 'requires_grad' entry is falsy (defaults to True when absent).
        teacher_frozen_modules = teacher_config.get('frozen_modules', [])
        self.teacher_any_frozen = \
            len(teacher_frozen_modules) > 0 or not teacher_config.get('requires_grad', True)
        student_frozen_modules = student_config.get('frozen_modules', [])
        self.student_any_frozen = \
            len(student_frozen_modules) > 0 or not student_config.get('requires_grad', True)

        teacher_pairs = set_hooks(self.teacher_model, teacher_ref_model,
                                  teacher_config, self.teacher_io_dict)
        self.target_teacher_pairs.extend(teacher_pairs)
        student_pairs = set_hooks(self.student_model, student_ref_model,
                                  student_config, self.student_io_dict)
        self.target_student_pairs.extend(student_pairs)

        teacher_proc_key = teacher_config.get('forward_proc', None)
        self.teacher_forward_proc = get_forward_proc_func(teacher_proc_key)
        student_proc_key = student_config.get('forward_proc', None)
        self.student_forward_proc = get_forward_proc_func(student_proc_key)
    def __init__(self, teacher_model, minimal, input_module_path,
                 paraphraser_params, paraphraser_ckpt, uses_decoder, device,
                 device_ids, distributed, **kwargs):
        """
        Wrap a teacher model together with a paraphraser module.

        Args:
            teacher_model: teacher model to wrap; used as the reference model
                unless ``build_special_module`` returns a replacement.
            minimal: configuration passed to ``build_special_module`` and
                ``redesign_model``; ``None`` is treated as an empty dict.
            input_module_path: stored unchanged on the instance.
            paraphraser_params: keyword arguments used to construct
                ``Paraphraser4FactorTransfer``.
            paraphraser_ckpt: path of the paraphraser checkpoint; it is
                loaded only if a file exists at this path.
            uses_decoder: stored unchanged on the instance.
            device: passed to ``wrap_if_distributed`` and used as the
                checkpoint ``map_location`` when not distributed.
            device_ids: passed to ``wrap_if_distributed``; its first entry
                selects the target CUDA device when distributed.
            distributed: whether the distributed code path is used.
            **kwargs: accepted but not used here.
        """
        super().__init__()
        minimal = dict() if minimal is None else minimal

        # Prefer the special module when one is built; otherwise keep the
        # given teacher and label it 'original'.
        teacher_ref_model = build_special_module(
            minimal, teacher_model=teacher_model)
        if teacher_ref_model is None:
            teacher_ref_model = teacher_model
            model_type = 'original'
        else:
            model_type = type(teacher_ref_model).__name__

        self.teacher_model = redesign_model(teacher_ref_model, minimal,
                                            'teacher', model_type)
        self.input_module_path = input_module_path

        paraphraser = Paraphraser4FactorTransfer(**paraphraser_params)
        self.paraphraser = wrap_if_distributed(paraphraser, device,
                                               device_ids, distributed)

        self.ckpt_file_path = paraphraser_ckpt
        if os.path.isfile(self.ckpt_file_path):
            if distributed:
                # Remap 'cuda:0' entries onto the first listed device id.
                map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])}
            else:
                map_location = device
            load_module_ckpt(self.paraphraser, map_location,
                             self.ckpt_file_path)

        self.uses_decoder = uses_decoder
Exemple #3
0
    def __init__(self, head_rcnn, **kwargs):
        """
        Build the module from a reference model taken from ``kwargs``.

        The reference model is ``kwargs['student_model']`` when it is present
        and not ``None``; otherwise ``kwargs['teacher_model']`` is used.

        Args:
            head_rcnn: configuration forwarded to ``redesign_model``.
            **kwargs: expected to carry ``student_model`` and/or
                ``teacher_model``.

        Raises:
            ValueError: if neither ``student_model`` nor ``teacher_model``
                yields a non-``None`` model.
        """
        super().__init__()
        # ``dict.get(key, default)`` returns the default only when the key is
        # absent, so an explicit ``student_model=None`` would previously hide
        # a valid ``teacher_model``. Fall back on a None *value* as well.
        ref_model = kwargs.get('student_model')
        if ref_model is None:
            ref_model = kwargs.get('teacher_model')
        if ref_model is None:
            raise ValueError('Either student_model or teacher_model has to be given.')

        self.transform = ref_model.transform
        self.seq = redesign_model(ref_model, head_rcnn, 'R-CNN', 'HeadRCNN')