def setup_teacher_student_models(self, teacher_config, student_config):
    # Unwrap DataParallel/DistributedDataParallel wrappers to get the raw modules.
    unwrapped_org_teacher_model = \
        self.org_teacher_model.module if check_if_wrapped(self.org_teacher_model) else self.org_teacher_model
    unwrapped_org_student_model = \
        self.org_student_model.module if check_if_wrapped(self.org_student_model) else self.org_student_model
    self.target_teacher_pairs.clear()
    self.target_student_pairs.clear()
    teacher_ref_model = unwrapped_org_teacher_model
    student_ref_model = unwrapped_org_student_model

    # Rebuild the teacher only when a config is given or no teacher has been set up yet.
    if len(teacher_config) > 0 or (len(teacher_config) == 0 and self.teacher_model is None):
        model_type = 'original'
        special_teacher_model = \
            build_special_module(teacher_config, teacher_model=unwrapped_org_teacher_model, device=self.device,
                                 device_ids=self.device_ids, distributed=self.distributed)
        if special_teacher_model is not None:
            teacher_ref_model = special_teacher_model
            model_type = type(teacher_ref_model).__name__
        self.teacher_model = redesign_model(teacher_ref_model, teacher_config, 'teacher', model_type)

    # Same procedure for the student model.
    if len(student_config) > 0 or (len(student_config) == 0 and self.student_model is None):
        model_type = 'original'
        special_student_model = \
            build_special_module(student_config, student_model=unwrapped_org_student_model, device=self.device,
                                 device_ids=self.device_ids, distributed=self.distributed)
        if special_student_model is not None:
            student_ref_model = special_student_model
            model_type = type(student_ref_model).__name__
        self.student_model = redesign_model(student_ref_model, student_config, 'student', model_type)

    # Record whether any part of each model is frozen, register hooks that populate the
    # intermediate I/O dictionaries, and resolve the forward procedures.
    self.teacher_any_frozen = \
        len(teacher_config.get('frozen_modules', list())) > 0 or not teacher_config.get('requires_grad', True)
    self.student_any_frozen = \
        len(student_config.get('frozen_modules', list())) > 0 or not student_config.get('requires_grad', True)
    self.target_teacher_pairs.extend(
        set_hooks(self.teacher_model, teacher_ref_model, teacher_config, self.teacher_io_dict))
    self.target_student_pairs.extend(
        set_hooks(self.student_model, student_ref_model, student_config, self.student_io_dict))
    self.teacher_forward_proc = get_forward_proc_func(teacher_config.get('forward_proc', None))
    self.student_forward_proc = get_forward_proc_func(student_config.get('forward_proc', None))
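# Illustrative sketch, not part of the original source: the teacher/student config dicts consumed
# above are assumed to look roughly like the following, based only on the keys this method reads
# ('frozen_modules', 'requires_grad', 'forward_proc'). The module path below is a hypothetical
# placeholder.
example_teacher_config = {
    'requires_grad': False,   # treat the whole teacher as frozen
    'frozen_modules': [],     # or list individual module paths to freeze
    'forward_proc': None,     # fall back to the default forward procedure
}
example_student_config = {
    'requires_grad': True,
    'frozen_modules': ['backbone.layer1'],  # hypothetical module path
    'forward_proc': None,
}
# distillation_box.setup_teacher_student_models(example_teacher_config, example_student_config)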
def __init__(self, teacher_model, minimal, input_module_path, paraphraser_params, paraphraser_ckpt,
             uses_decoder, device, device_ids, distributed, **kwargs):
    super().__init__()
    if minimal is None:
        minimal = dict()

    # Optionally replace the teacher with a special module built from the minimal config.
    special_teacher_model = build_special_module(minimal, teacher_model=teacher_model)
    model_type = 'original'
    teacher_ref_model = teacher_model
    if special_teacher_model is not None:
        teacher_ref_model = special_teacher_model
        model_type = type(teacher_ref_model).__name__

    self.teacher_model = redesign_model(teacher_ref_model, minimal, 'teacher', model_type)
    self.input_module_path = input_module_path
    self.paraphraser = \
        wrap_if_distributed(Paraphraser4FactorTransfer(**paraphraser_params), device, device_ids, distributed)
    self.ckpt_file_path = paraphraser_ckpt
    # Load a pretrained paraphraser checkpoint if one exists, remapping it to the local device.
    if os.path.isfile(self.ckpt_file_path):
        map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
        load_module_ckpt(self.paraphraser, map_location, self.ckpt_file_path)
    self.uses_decoder = uses_decoder
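# Illustrative sketch, not part of the original source: how the map_location built above behaves.
# In the distributed branch, tensors saved from 'cuda:0' are remapped onto this process's first
# visible device; otherwise they are loaded onto the single `device`. torch.load is used here only
# to demonstrate the remapping; load_module_ckpt is assumed to wrap similar logic. The checkpoint
# path and device id are hypothetical.
import torch

device_ids = [1]
map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])}
state_dict = torch.load('./paraphraser_ckpt.pt', map_location=map_location)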
def __init__(self, head_rcnn, **kwargs):
    super().__init__()
    tmp_ref_model = kwargs.get('teacher_model', None)
    ref_model = kwargs.get('student_model', tmp_ref_model)
    if ref_model is None:
        raise ValueError('Either student_model or teacher_model has to be given.')

    self.transform = ref_model.transform
    self.seq = redesign_model(ref_model, head_rcnn, 'R-CNN', 'HeadRCNN')
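# Illustrative sketch, not part of the original source: HeadRCNN is assumed to wrap a
# torchvision-style detection model passed via the student_model/teacher_model keyword, with a
# config telling redesign_model() which child modules to keep. The 'sequential' key and the module
# names ('backbone', 'rpn', 'roi_heads') are assumptions based on torchvision's Faster R-CNN layout.
from torchvision.models.detection import fasterrcnn_resnet50_fpn

student = fasterrcnn_resnet50_fpn(weights=None)
head_rcnn_config = {'sequential': ['backbone', 'rpn', 'roi_heads']}
head_rcnn = HeadRCNN(head_rcnn_config, student_model=student)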